olmo3.go 3.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package renderers

import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"github.com/ollama/ollama/api"
)

// Prompt fragments that Olmo3Renderer.Render concatenates to build the system
// turn when the conversation does not supply a system message of its own.
const (
	// olmo3DefaultSystemMessage opens the default system turn. The trailing
	// space is significant: one of the two fragments below is appended
	// directly after it with no separator.
	olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
	// olmo3NoFunctionsMessage follows the default message when no tools are
	// provided; Render then appends an empty <functions></functions> pair.
	olmo3NoFunctionsMessage   = "You do not currently have access to any functions. "
	// olmo3WithFunctionsMessage follows the default message when at least one
	// tool is provided; Render then appends the tool list wrapped in
	// <functions>...</functions>.
	olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
)

// Olmo3Renderer turns a chat transcript and an optional tool list into a
// prompt string made of <|im_start|>/<|im_end|> delimited turns (see Render).
// It has no fields, so the zero value is ready to use.
type Olmo3Renderer struct{}

// Render builds the prompt string for a conversation.
//
// Output layout:
//   - Exactly one system turn comes first. If messages contains a system
//     message, the first one is used verbatim (any further system messages are
//     dropped) and, when tools are present, followed by the tool list inside
//     <functions>...</functions>. Otherwise a default system prompt is
//     written, always followed by a <functions> element (empty when there are
//     no tools).
//   - "user" messages become user turns, "tool" messages become environment
//     turns, and "assistant" messages become assistant turns whose tool calls
//     are written as name(key=<json>, ...) lines inside
//     <function_calls>...</function_calls>, with keys sorted for determinism.
//     Messages with any other role are skipped.
//   - If the final message is an assistant message with non-empty content and
//     no tool calls, it is treated as a prefill: its turn is left open so
//     generation continues it. In every other case all turns are closed and a
//     fresh "<|im_start|>assistant\n\n" generation prompt is appended.
//
// The third parameter (thinking configuration) is ignored.
//
// An error is returned if the tool list or a tool-call argument cannot be
// JSON-encoded; the returned string is empty in that case.
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
	var sb strings.Builder

	// Split off the system message: only the first one is rendered, and no
	// system message takes part in the turn loop below.
	var systemMessage *api.Message
	filteredMessages := make([]api.Message, 0, len(messages))
	for i := range messages {
		if messages[i].Role == "system" {
			if systemMessage == nil {
				systemMessage = &messages[i]
			}
			continue
		}
		filteredMessages = append(filteredMessages, messages[i])
	}

	// System turn - single newline after "system".
	sb.WriteString("<|im_start|>system\n")
	if systemMessage != nil {
		sb.WriteString(systemMessage.Content)
	} else {
		sb.WriteString(olmo3DefaultSystemMessage)
		if len(tools) > 0 {
			sb.WriteString(olmo3WithFunctionsMessage)
		} else {
			sb.WriteString(olmo3NoFunctionsMessage)
		}
	}
	switch {
	case len(tools) > 0:
		functionsJSON, err := marshalWithSpaces(tools)
		if err != nil {
			return "", fmt.Errorf("olmo3: encoding tools: %w", err)
		}
		sb.WriteString("<functions>")
		sb.WriteString(string(functionsJSON))
		sb.WriteString("</functions>")
	case systemMessage == nil:
		// Only the default prompt advertises an empty function list; a custom
		// system message without tools gets no <functions> element at all.
		sb.WriteString("<functions></functions>")
	}
	sb.WriteString("<|im_end|>\n")

	// Decide once whether the conversation ends in a prefill so that the
	// "leave the turn open" and "skip the generation prompt" decisions can
	// never disagree. (Previously an empty final assistant message was left
	// unterminated and then followed by a second assistant header.)
	prefill := false
	if n := len(filteredMessages); n > 0 {
		last := &filteredMessages[n-1]
		prefill = last.Role == "assistant" && len(last.ToolCalls) == 0 && last.Content != ""
	}

	for i := range filteredMessages {
		message := &filteredMessages[i]
		lastMessage := i == len(filteredMessages)-1

		switch message.Role {
		case "user":
			sb.WriteString("<|im_start|>user\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")

		case "assistant":
			sb.WriteString("<|im_start|>assistant\n")
			sb.WriteString(message.Content)

			if len(message.ToolCalls) > 0 {
				sb.WriteString("<function_calls>")
				for j, tc := range message.ToolCalls {
					// Format as function_name(arg1="value1", arg2="value2")
					sb.WriteString(tc.Function.Name)
					sb.WriteString("(")

					// Map iteration order is random; sort keys for
					// deterministic output.
					keys := make([]string, 0, len(tc.Function.Arguments))
					for k := range tc.Function.Arguments {
						keys = append(keys, k)
					}
					sort.Strings(keys)

					for k, key := range keys {
						if k > 0 {
							sb.WriteString(", ")
						}
						value, err := json.Marshal(tc.Function.Arguments[key])
						if err != nil {
							return "", fmt.Errorf("olmo3: encoding argument %q of tool call %q: %w", key, tc.Function.Name, err)
						}
						fmt.Fprintf(&sb, "%s=%s", key, value)
					}
					sb.WriteString(")")

					// Calls are newline-separated, with no trailing newline.
					if j < len(message.ToolCalls)-1 {
						sb.WriteString("\n")
					}
				}
				sb.WriteString("</function_calls>")
			}

			// Close the turn unless this is the trailing prefill message.
			if !lastMessage || !prefill {
				sb.WriteString("<|im_end|>\n")
			}

		case "tool":
			// Tool results are presented to the model as "environment" turns.
			sb.WriteString("<|im_start|>environment\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")
		}
	}

	// Add the generation prompt unless the model is continuing a prefill.
	if !prefill {
		sb.WriteString("<|im_start|>assistant\n\n")
	}

	return sb.String(), nil
}