openai_test.go 4.83 KB
Newer Older
1
2
3
package openai

import (
4
	"encoding/base64"
5
	"testing"
Michael Yang's avatar
lint  
Michael Yang committed
6

Grace's avatar
Grace committed
7
8
	"github.com/google/go-cmp/cmp"

Michael Yang's avatar
lint  
Michael Yang committed
9
	"github.com/ollama/ollama/api"
10
11
)

Michael Yang's avatar
lint  
Michael Yang committed
12
// Fixtures shared by the image-handling tests in this file.
const (
	// prefix is the data-URL header placed in front of the base64 payload
	// when a test builds an image_url content part.
	prefix = `data:image/jpeg;base64,`
	// image is a small base64-encoded image payload.
	// NOTE(review): the payload starts with the PNG signature ("iVBORw0KGgo")
	// while prefix declares image/jpeg; the tests here only compare decoded
	// bytes, so this looks harmless — confirm the mismatch is intentional.
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
16

17
18
19
20
21
22
func TestFromChatRequest_Basic(t *testing.T) {
	req := ChatCompletionRequest{
		Model: "test-model",
		Messages: []Message{
			{Role: "user", Content: "Hello"},
		},
23
24
	}

25
26
27
	result, err := FromChatRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
28
29
	}

30
31
	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got %q", result.Model)
32
33
	}

34
35
	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
36
37
	}

38
39
	if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
		t.Errorf("unexpected message: %+v", result.Messages[0])
40
41
42
	}
}

43
44
func TestFromChatRequest_WithImage(t *testing.T) {
	imgData, _ := base64.StdEncoding.DecodeString(image)
45

46
47
48
49
50
51
52
53
54
55
56
	req := ChatCompletionRequest{
		Model: "test-model",
		Messages: []Message{
			{
				Role: "user",
				Content: []any{
					map[string]any{"type": "text", "text": "Hello"},
					map[string]any{
						"type":      "image_url",
						"image_url": map[string]any{"url": prefix + image},
					},
57
				},
58
59
60
			},
		},
	}
61

62
63
64
	result, err := FromChatRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
65
	}
66

67
68
69
	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}
70

71
72
73
	if result.Messages[0].Content != "Hello" {
		t.Errorf("expected first message content 'Hello', got %q", result.Messages[0].Content)
	}
74

75
76
77
	if len(result.Messages[1].Images) != 1 {
		t.Fatalf("expected 1 image, got %d", len(result.Messages[1].Images))
	}
78

79
80
	if string(result.Messages[1].Images[0]) != string(imgData) {
		t.Error("image data mismatch")
81
82
83
	}
}

84
85
86
87
88
89
func TestFromCompleteRequest_Basic(t *testing.T) {
	temp := float32(0.8)
	req := CompletionRequest{
		Model:       "test-model",
		Prompt:      "Hello",
		Temperature: &temp,
90
91
	}

92
93
94
	result, err := FromCompleteRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
95
	}
96

97
98
	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got %q", result.Model)
royjhan's avatar
royjhan committed
99
	}
100

101
102
	if result.Prompt != "Hello" {
		t.Errorf("expected prompt 'Hello', got %q", result.Prompt)
royjhan's avatar
royjhan committed
103
	}
104

105
106
	if tempVal, ok := result.Options["temperature"].(float32); !ok || tempVal != 0.8 {
		t.Errorf("expected temperature 0.8, got %v", result.Options["temperature"])
royjhan's avatar
royjhan committed
107
	}
108
}
royjhan's avatar
royjhan committed
109

110
111
112
113
114
func TestToUsage(t *testing.T) {
	resp := api.ChatResponse{
		Metrics: api.Metrics{
			PromptEvalCount: 10,
			EvalCount:       20,
115
116
		},
	}
117

118
	usage := ToUsage(resp)
119

120
121
	if usage.PromptTokens != 10 {
		t.Errorf("expected PromptTokens 10, got %d", usage.PromptTokens)
122
123
	}

124
125
	if usage.CompletionTokens != 20 {
		t.Errorf("expected CompletionTokens 20, got %d", usage.CompletionTokens)
126
127
	}

128
129
	if usage.TotalTokens != 30 {
		t.Errorf("expected TotalTokens 30, got %d", usage.TotalTokens)
130
	}
131
}
132

133
134
135
136
137
138
139
140
141
func TestNewError(t *testing.T) {
	tests := []struct {
		code int
		want string
	}{
		{400, "invalid_request_error"},
		{404, "not_found_error"},
		{500, "api_error"},
	}
142

143
144
145
146
	for _, tt := range tests {
		result := NewError(tt.code, "test message")
		if result.Error.Type != tt.want {
			t.Errorf("NewError(%d) type = %q, want %q", tt.code, result.Error.Type, tt.want)
147
		}
148
149
		if result.Error.Message != "test message" {
			t.Errorf("NewError(%d) message = %q, want %q", tt.code, result.Error.Message, "test message")
150
		}
151
152
	}
}
Grace's avatar
Grace committed
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220

// TestToToolCallsPreservesIDs converts two api.ToolCall values with
// ToToolCalls and checks that ID, index, name, and JSON-encoded arguments
// come out as expected, and that the slice handed to ToToolCalls still
// compares equal to the input afterwards.
func TestToToolCallsPreservesIDs(t *testing.T) {
	// fn is an alias for the anonymous struct type used by the Function
	// field of ToolCall, so the expectations below stay compact.
	type fn = struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	}

	input := []api.ToolCall{
		{
			ID: "call_abc123",
			Function: api.ToolCallFunction{
				Index:     2,
				Name:      "get_weather",
				Arguments: api.ToolCallFunctionArguments{"location": "Seattle"},
			},
		},
		{
			ID: "call_def456",
			Function: api.ToolCallFunction{
				Index:     7,
				Name:      "get_time",
				Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
			},
		},
	}

	// Hand ToToolCalls a separate slice so the input can be diffed afterwards.
	// NOTE(review): this is a shallow copy; if Arguments is a map type (the
	// literal syntax suggests so) both slices share it, so in-place changes
	// to the arguments would not be detected — confirm this is acceptable.
	passed := append([]api.ToolCall(nil), input...)
	converted := ToToolCalls(passed)

	if len(converted) != len(input) {
		t.Fatalf("expected %d tool calls, got %d", len(input), len(converted))
	}

	want := []ToolCall{
		{
			ID:       "call_abc123",
			Type:     "function",
			Index:    2,
			Function: fn{Name: "get_weather", Arguments: `{"location":"Seattle"}`},
		},
		{
			ID:       "call_def456",
			Type:     "function",
			Index:    7,
			Function: fn{Name: "get_time", Arguments: `{"timezone":"UTC"}`},
		},
	}

	if delta := cmp.Diff(want, converted); delta != "" {
		t.Errorf("tool calls mismatch (-want +got):\n%s", delta)
	}

	if delta := cmp.Diff(input, passed); delta != "" {
		t.Errorf("input tool calls mutated (-want +got):\n%s", delta)
	}
}