package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
)

func TestOlmo3ThinkRenderer(t *testing.T) {
	tests := []struct {
		name     string
14
		variant  Olmo3ThinkVariant
15
16
17
18
19
		msgs     []api.Message
		tools    []api.Tool
		expected string
	}{
		{
20
21
			name:    "7b_basic_without_system",
			variant: Olmo31Think,
22
23
24
25
			msgs: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			expected: "<|im_start|>system\n" +
26
				"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
27
28
29
30
31
32
				"<|im_start|>user\n" +
				"Hello!<|im_end|>\n" +
				"<|im_start|>assistant\n" +
				"<think>",
		},
		{
33
34
			name:    "7b_with_custom_system",
			variant: Olmo31Think,
35
36
37
38
39
40
41
42
43
44
45
46
			msgs: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello!"},
			},
			expected: "<|im_start|>system\n" +
				"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
				"<|im_start|>user\n" +
				"Hello!<|im_end|>\n" +
				"<|im_start|>assistant\n" +
				"<think>",
		},
		{
47
48
			name:    "7b_tools_ignored",
			variant: Olmo31Think,
49
50
51
52
53
54
55
56
57
58
59
60
61
			msgs: []api.Message{
				{Role: "user", Content: "What is the weather?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get the current weather",
					},
				},
			},
			expected: "<|im_start|>system\n" +
62
				"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
63
64
65
66
67
68
				"<|im_start|>user\n" +
				"What is the weather?<|im_end|>\n" +
				"<|im_start|>assistant\n" +
				"<think>",
		},
		{
69
70
			name:    "7b_tool_calls_and_tool_messages_ignored",
			variant: Olmo31Think,
71
72
73
74
75
76
77
78
79
			msgs: []api.Message{
				{Role: "user", Content: "What is the weather in SF?"},
				{
					Role:    "assistant",
					Content: "Let me check the weather.",
					ToolCalls: []api.ToolCall{
						{
							ID: "call_1",
							Function: api.ToolCallFunction{
80
								Name:      "get_weather",
81
								Arguments: testArgs(map[string]any{"location": "San Francisco"}),
82
83
84
85
							},
						},
					},
				},
86
				{Role: "tool", Content: `{"temperature": 68}`},
87
88
			},
			expected: "<|im_start|>system\n" +
89
				"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
90
91
92
				"<|im_start|>user\n" +
				"What is the weather in SF?<|im_end|>\n" +
				"<|im_start|>assistant\n" +
93
				"Let me check the weather.<|im_end|>\n" +
94
95
96
97
				"<|im_start|>assistant\n" +
				"<think>",
		},
		{
98
99
			name:    "7b_multi_turn_conversation",
			variant: Olmo31Think,
100
101
102
103
104
105
			msgs: []api.Message{
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: "Hi there!"},
				{Role: "user", Content: "How are you?"},
			},
			expected: "<|im_start|>system\n" +
106
				"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
107
108
109
110
111
112
113
114
115
116
				"<|im_start|>user\n" +
				"Hello<|im_end|>\n" +
				"<|im_start|>assistant\n" +
				"Hi there!<|im_end|>\n" +
				"<|im_start|>user\n" +
				"How are you?<|im_end|>\n" +
				"<|im_start|>assistant\n" +
				"<think>",
		},
		{
117
118
			name:    "32b_basic_without_system",
			variant: Olmo3Think32B,
119
			msgs: []api.Message{
120
				{Role: "user", Content: "Hello!"},
121
122
			},
			expected: "<|im_start|>system\n" +
123
				"You are a helpful AI assistant.<|im_end|>\n" +
124
				"<|im_start|>user\n" +
125
				"Hello!<|im_end|>\n" +
126
				"<|im_start|>assistant\n" +
127
128
129
130
131
132
133
134
135
136
137
138
139
				"<think>",
		},
		{
			name:    "32b_with_custom_system_gets_suffix",
			variant: Olmo3Think32B,
			msgs: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello!"},
			},
			expected: "<|im_start|>system\n" +
				"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
				"<|im_start|>user\n" +
				"Hello!<|im_end|>\n" +
140
141
142
143
				"<|im_start|>assistant\n" +
				"<think>",
		},
		{
144
145
			name:    "31_basic_without_system",
			variant: Olmo31Think,
146
			msgs: []api.Message{
147
				{Role: "user", Content: "Hello!"},
148
149
			},
			expected: "<|im_start|>system\n" +
150
				"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
151
				"<|im_start|>user\n" +
152
				"Hello!<|im_end|>\n" +
153
				"<|im_start|>assistant\n" +
154
155
156
157
158
159
160
161
162
163
164
				"<think>",
		},
		{
			name:    "31_with_custom_system_gets_suffix",
			variant: Olmo31Think,
			msgs: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello!"},
			},
			expected: "<|im_start|>system\n" +
				"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
165
				"<|im_start|>user\n" +
166
				"Hello!<|im_end|>\n" +
167
168
169
170
171
172
173
				"<|im_start|>assistant\n" +
				"<think>",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
174
			rendered, err := (&Olmo3ThinkRenderer{Variant: tt.variant}).Render(tt.msgs, tt.tools, nil)
175
176
177
178
179
180
181
182
183
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
			}
		})
	}
}