package openai

import (
	"encoding/base64"
	"testing"

	"github.com/ollama/ollama/api"
)

// Shared fixtures for the image-handling tests.
//
// prefix is the data-URL header that is prepended to image to build an
// OpenAI-style "image_url" content part. image is the bare base64 payload of a
// 1x1 PNG (note the "iVBORw0KGgo" PNG signature) even though prefix advertises
// image/jpeg; TestFromChatRequest_WithImage only checks that the decoded bytes
// come through the conversion unchanged.
const (
	prefix = `data:image/jpeg;base64,`
	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

func TestFromChatRequest_Basic(t *testing.T) {
	req := ChatCompletionRequest{
		Model: "test-model",
		Messages: []Message{
			{Role: "user", Content: "Hello"},
		},
21
22
	}

23
24
25
	result, err := FromChatRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
26
27
	}

28
29
	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got %q", result.Model)
30
31
	}

32
33
	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
34
35
	}

36
37
	if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
		t.Errorf("unexpected message: %+v", result.Messages[0])
38
39
40
	}
}

41
42
func TestFromChatRequest_WithImage(t *testing.T) {
	imgData, _ := base64.StdEncoding.DecodeString(image)
43

44
45
46
47
48
49
50
51
52
53
54
	req := ChatCompletionRequest{
		Model: "test-model",
		Messages: []Message{
			{
				Role: "user",
				Content: []any{
					map[string]any{"type": "text", "text": "Hello"},
					map[string]any{
						"type":      "image_url",
						"image_url": map[string]any{"url": prefix + image},
					},
55
				},
56
57
58
			},
		},
	}
59

60
61
62
	result, err := FromChatRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
63
	}
64

65
66
67
	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}
68

69
70
71
	if result.Messages[0].Content != "Hello" {
		t.Errorf("expected first message content 'Hello', got %q", result.Messages[0].Content)
	}
72

73
74
75
	if len(result.Messages[1].Images) != 1 {
		t.Fatalf("expected 1 image, got %d", len(result.Messages[1].Images))
	}
76

77
78
	if string(result.Messages[1].Images[0]) != string(imgData) {
		t.Error("image data mismatch")
79
80
81
	}
}

82
83
84
85
86
87
func TestFromCompleteRequest_Basic(t *testing.T) {
	temp := float32(0.8)
	req := CompletionRequest{
		Model:       "test-model",
		Prompt:      "Hello",
		Temperature: &temp,
88
89
	}

90
91
92
	result, err := FromCompleteRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
93
	}
94

95
96
	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got %q", result.Model)
royjhan's avatar
royjhan committed
97
	}
98

99
100
	if result.Prompt != "Hello" {
		t.Errorf("expected prompt 'Hello', got %q", result.Prompt)
royjhan's avatar
royjhan committed
101
	}
102

103
104
	if tempVal, ok := result.Options["temperature"].(float32); !ok || tempVal != 0.8 {
		t.Errorf("expected temperature 0.8, got %v", result.Options["temperature"])
royjhan's avatar
royjhan committed
105
	}
106
}
royjhan's avatar
royjhan committed
107

108
109
110
111
112
func TestToUsage(t *testing.T) {
	resp := api.ChatResponse{
		Metrics: api.Metrics{
			PromptEvalCount: 10,
			EvalCount:       20,
113
114
		},
	}
115

116
	usage := ToUsage(resp)
117

118
119
	if usage.PromptTokens != 10 {
		t.Errorf("expected PromptTokens 10, got %d", usage.PromptTokens)
120
121
	}

122
123
	if usage.CompletionTokens != 20 {
		t.Errorf("expected CompletionTokens 20, got %d", usage.CompletionTokens)
124
125
	}

126
127
	if usage.TotalTokens != 30 {
		t.Errorf("expected TotalTokens 30, got %d", usage.TotalTokens)
128
	}
129
}
130

131
132
133
134
135
136
137
138
139
func TestNewError(t *testing.T) {
	tests := []struct {
		code int
		want string
	}{
		{400, "invalid_request_error"},
		{404, "not_found_error"},
		{500, "api_error"},
	}
140

141
142
143
144
	for _, tt := range tests {
		result := NewError(tt.code, "test message")
		if result.Error.Type != tt.want {
			t.Errorf("NewError(%d) type = %q, want %q", tt.code, result.Error.Type, tt.want)
145
		}
146
147
		if result.Error.Message != "test message" {
			t.Errorf("NewError(%d) message = %q, want %q", tt.code, result.Error.Message, "test message")
148
		}
149
150
	}
}