olmo3_think.go 3.29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package renderers

import (
	"encoding/json"
	"strings"

	"github.com/ollama/ollama/api"
)

const (
	// olmo3ThinkDefaultSystemMessage is the system prompt Render emits when
	// the conversation contains no "system" message of its own.
	olmo3ThinkDefaultSystemMessage = "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai."

	// olmo3ThinkNoFunctionsMessage is appended directly after the system
	// content when no tools are supplied (hence the leading space); Render
	// follows it with an empty "<functions></functions>" section.
	olmo3ThinkNoFunctionsMessage = " You do not currently have access to any functions."
)

// Olmo3ThinkRenderer renders chat messages and tool definitions into the
// ChatML-style prompt (<|im_start|>role ... <|im_end|>) used for OLMo 3
// Think models. It carries no state; the zero value is ready to use.
type Olmo3ThinkRenderer struct{}

// olmo3ThinkToolCallFunc is the "function" object of a rendered tool call:
// the function name plus its arguments, already serialized to a JSON string.
type olmo3ThinkToolCallFunc struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// olmo3ThinkToolCall is the JSON shape in which an assistant tool call is
// written inside <function_calls>...</function_calls>. ID and Type are
// omitted from the output when empty.
type olmo3ThinkToolCall struct {
	ID       string                 `json:"id,omitempty"`
	Type     string                 `json:"type,omitempty"`
	Function olmo3ThinkToolCallFunc `json:"function"`
}

// Render builds the OLMo 3 Think prompt for messages and tools.
//
// The output is laid out as follows:
//   - One system turn. It holds the content of the first "system" message, or
//     olmo3ThinkDefaultSystemMessage when there is none. That is followed by
//     the tool definitions inside " <functions>...</functions>", or by
//     olmo3ThinkNoFunctionsMessage and an empty "<functions></functions>"
//     when tools is empty. Any additional "system" messages are dropped.
//   - "user" messages verbatim; "assistant" messages with their content and
//     any tool calls as a JSON array inside <function_calls>...</function_calls>;
//     "tool" messages as "environment" turns. Other roles are skipped.
//   - A trailing "<|im_start|>assistant\n<think>" generation prompt. The
//     exception is when the final message is an assistant message with
//     content and no tool calls: that message is treated as a prefill, left
//     open (no <|im_end|>), and no generation prompt is added.
//
// The ThinkValue argument is ignored; the generation prompt always opens a
// <think> block. Errors come only from JSON-encoding tools or tool-call
// arguments.
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
	var sb strings.Builder

	// Only the first system message is honoured; all system messages are
	// removed from the conversation body.
	var systemMessage *api.Message
	filteredMessages := make([]api.Message, 0, len(messages))
	for i, message := range messages {
		if message.Role == "system" {
			if systemMessage == nil {
				systemMessage = &messages[i]
			}
			continue
		}
		filteredMessages = append(filteredMessages, message)
	}

	systemContent := olmo3ThinkDefaultSystemMessage
	if systemMessage != nil {
		systemContent = systemMessage.Content
	}

	sb.WriteString("<|im_start|>system\n")
	sb.WriteString(systemContent)
	if len(tools) > 0 {
		functionsJSON, err := marshalWithSpaces(tools)
		if err != nil {
			return "", err
		}
		sb.WriteString(" <functions>")
		sb.WriteString(string(functionsJSON))
		sb.WriteString("</functions>")
	} else {
		sb.WriteString(olmo3ThinkNoFunctionsMessage)
		sb.WriteString(" <functions></functions>")
	}
	sb.WriteString("<|im_end|>\n")

	// A final assistant message with content and no tool calls is a prefill:
	// it stays unterminated so the model continues it, and no new assistant
	// turn is opened. Every other conversation ends with a generation prompt.
	prefill := false
	if n := len(filteredMessages); n > 0 {
		last := filteredMessages[n-1]
		prefill = last.Role == "assistant" && len(last.ToolCalls) == 0 && last.Content != ""
	}
	lastIndex := len(filteredMessages) - 1

	for i, message := range filteredMessages {
		switch message.Role {
		case "user":
			sb.WriteString("<|im_start|>user\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")

		case "assistant":
			sb.WriteString("<|im_start|>assistant\n")
			sb.WriteString(message.Content)

			if len(message.ToolCalls) > 0 {
				toolCalls := make([]olmo3ThinkToolCall, len(message.ToolCalls))
				for j, tc := range message.ToolCalls {
					// Arguments are embedded as a JSON-encoded string, not as
					// a nested object.
					argsJSON, err := json.Marshal(tc.Function.Arguments)
					if err != nil {
						return "", err
					}
					toolCalls[j] = olmo3ThinkToolCall{
						ID:   tc.ID,
						Type: "function",
						Function: olmo3ThinkToolCallFunc{
							Name:      tc.Function.Name,
							Arguments: string(argsJSON),
						},
					}
				}
				toolCallsJSON, err := marshalWithSpaces(toolCalls)
				if err != nil {
					return "", err
				}
				sb.WriteString("<function_calls>")
				sb.WriteString(string(toolCallsJSON))
				sb.WriteString("</function_calls>")
			}

			// Only the prefill message is left open. Any other final assistant
			// turn (one with tool calls, or with no content) must be closed.
			// Otherwise the generation prompt below would be glued onto it.
			if !(prefill && i == lastIndex) {
				sb.WriteString("<|im_end|>\n")
			}

		case "tool":
			sb.WriteString("<|im_start|>environment\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")
		}
	}

	if !prefill {
		sb.WriteString("<|im_start|>assistant\n<think>")
	}

	return sb.String(), nil
}