openai_test.go 11.1 KB
Newer Older
1
2
3
4
package openai

import (
	"bytes"
5
	"encoding/base64"
6
7
8
9
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
10
	"reflect"
11
	"strings"
12
13
14
15
	"testing"
	"time"

	"github.com/gin-gonic/gin"
Michael Yang's avatar
lint  
Michael Yang committed
16
17

	"github.com/ollama/ollama/api"
18
19
)

Michael Yang's avatar
lint  
Michael Yang committed
20
// Fixtures shared by the middleware tests: prefix is the data-URL header that
// is prepended to image when building an "image_url" request part, and image
// is the base64-encoded payload that the chat test decodes to build the
// expected api.ImageData.
const (
	prefix = `data:image/jpeg;base64,`
	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

// False is an addressable false value; the expected requests below take its
// address for their Stream field.
var False = false

// captureRequestMiddleware returns a gin handler that decodes the JSON request
// body into capturedRequest so a test can inspect what reached this point in
// the handler chain.
//
// capturedRequest is passed straight to json.Unmarshal, so it must be a
// non-nil pointer. The body is re-buffered after being read so that handlers
// running later can still consume it. If the body cannot be read or decoded,
// the request is aborted with a 500 status and no further handlers run.
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
	return func(c *gin.Context) {
		bodyBytes, err := io.ReadAll(c.Request.Body)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to read request body")
			return
		}
		// Restore the body: ReadAll drained the original reader.
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

		if err := json.Unmarshal(bodyBytes, capturedRequest); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
			return
		}

		c.Next()
	}
}

func TestChatMiddleware(t *testing.T) {
40
	type testCase struct {
41
42
43
44
		name string
		body string
		req  api.ChatRequest
		err  ErrorResponse
45
46
	}

47
	var capturedRequest *api.ChatRequest
48

royjhan's avatar
royjhan committed
49
50
	testCases := []testCase{
		{
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
			name: "chat handler",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
71
			},
72
73
		},
		{
74
75
76
77
78
79
80
81
82
83
			name: "chat handler with image content",
			body: `{
				"model": "test-model",
				"messages": [
					{
						"role": "user",
						"content": [
							{
								"type": "text",
								"text": "Hello"
84
							},
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
							{
								"type": "image_url",
								"image_url": {
									"url": "` + prefix + image + `"
								}
							}
						]
					}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
					{
						Role: "user",
						Images: []api.ImageData{
							func() []byte {
								img, _ := base64.StdEncoding.DecodeString(image)
								return img
							}(),
109
110
						},
					},
111
112
113
114
115
116
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
117
118
119
			},
		},
		{
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
			name: "chat handler with tools",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]interface{}{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
146
							},
147
						},
148
					},
149
150
151
152
153
154
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
155
156
			},
		},
157

158
159
160
161
162
163
164
165
166
167
168
169
170
		{
			name: "chat handler error forwarding",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": 2}
				]
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid message content type: float64",
					Type:    "invalid_request_error",
				},
171
172
173
174
175
176
177
178
179
180
181
182
183
184
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/chat", endpoint)

	for _, tc := range testCases {
185
186
187
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")
188
189
190
191

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

192
193
194
195
196
197
198
199
200
			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}
			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}
201

202
203
204
			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}
205
206
207
208
209
210
211
			capturedRequest = nil
		})
	}
}

func TestCompletionsMiddleware(t *testing.T) {
	type testCase struct {
212
213
214
215
		name string
		body string
		req  api.GenerateRequest
		err  ErrorResponse
216
217
218
219
220
221
	}

	var capturedRequest *api.GenerateRequest

	testCases := []testCase{
		{
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
			name: "completions handler",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       1.6,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &False,
242
243
			},
		},
244
		{
245
246
247
248
249
250
251
252
253
254
255
256
257
			name: "completions handler error forwarding",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": null,
				"stop": [1, 2],
				"suffix": "suffix"
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid type for 'stop' field: float64",
					Type:    "invalid_request_error",
				},
258
259
260
			},
		},
	}
261

262
263
264
	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}
265

266
267
268
269
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/generate", endpoint)
270

271
	for _, tc := range testCases {
272
273
274
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")
275
276
277
278

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

279
280
281
282
283
284
285
286
287
288
289
290
291
292
			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}
293
294
295
296
297
298
299
300

			capturedRequest = nil
		})
	}
}

func TestEmbeddingsMiddleware(t *testing.T) {
	type testCase struct {
301
302
303
304
		name string
		body string
		req  api.EmbedRequest
		err  ErrorResponse
305
306
307
308
309
	}

	var capturedRequest *api.EmbedRequest

	testCases := []testCase{
310
		{
311
312
313
314
315
316
317
318
			name: "embed handler single input",
			body: `{
				"input": "Hello",
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: "Hello",
				Model: "test-model",
319
320
321
			},
		},
		{
322
323
324
325
326
327
328
329
			name: "embed handler batch input",
			body: `{
				"input": ["Hello", "World"],
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: []any{"Hello", "World"},
				Model: "test-model",
330
331
			},
		},
332
		{
333
334
335
336
337
338
339
340
341
			name: "embed handler error forwarding",
			body: `{
				"model": "test-model"
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid input",
					Type:    "invalid_request_error",
				},
342
343
344
			},
		},
	}
345

royjhan's avatar
royjhan committed
346
347
348
	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}
349

350
351
352
353
354
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/embed", endpoint)

royjhan's avatar
royjhan committed
355
	for _, tc := range testCases {
356
357
358
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")
359

royjhan's avatar
royjhan committed
360
361
			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)
362

363
364
365
366
367
368
369
370
371
372
373
374
375
376
			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}
377
378

			capturedRequest = nil
royjhan's avatar
royjhan committed
379
380
381
		})
	}
}
382

383
func TestListMiddleware(t *testing.T) {
royjhan's avatar
royjhan committed
384
	type testCase struct {
385
386
387
		name     string
		endpoint func(c *gin.Context)
		resp     string
royjhan's avatar
royjhan committed
388
389
390
	}

	testCases := []testCase{
391
		{
392
393
			name: "list handler",
			endpoint: func(c *gin.Context) {
394
395
396
				c.JSON(http.StatusOK, api.ListResponse{
					Models: []api.ListModelResponse{
						{
397
398
							Name:       "test-model",
							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
399
400
401
402
						},
					},
				})
			},
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
			resp: `{
				"object": "list",
				"data": [
					{
						"id": "test-model",
						"object": "model",
						"created": 1686935002,
						"owned_by": "library"
					}
				]
			}`,
		},
		{
			name: "list handler empty output",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{})
			},
			resp: `{
				"object": "list",
				"data": null
			}`,
		},
	}
426

427
	gin.SetMode(gin.TestMode)
428

429
430
431
432
433
	for _, tc := range testCases {
		router := gin.New()
		router.Use(ListMiddleware())
		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
434

435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)

		var expected, actual map[string]any
		err := json.Unmarshal([]byte(tc.resp), &expected)
		if err != nil {
			t.Fatalf("failed to unmarshal expected response: %v", err)
		}

		err = json.Unmarshal(resp.Body.Bytes(), &actual)
		if err != nil {
			t.Fatalf("failed to unmarshal actual response: %v", err)
		}

		if !reflect.DeepEqual(expected, actual) {
			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
		}
	}
}

func TestRetrieveMiddleware(t *testing.T) {
	type testCase struct {
		name     string
		endpoint func(c *gin.Context)
		resp     string
	}

	testCases := []testCase{
463
		{
464
465
			name: "retrieve handler",
			endpoint: func(c *gin.Context) {
466
				c.JSON(http.StatusOK, api.ShowResponse{
467
					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
468
469
				})
			},
470
471
472
473
474
475
476
477
478
479
480
			resp: `{
				"id":"test-model",
				"object":"model",
				"created":1686935002,
				"owned_by":"library"}
			`,
		},
		{
			name: "retrieve handler error forwarding",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
481
			},
482
483
484
485
486
487
488
489
			resp: `{
				"error": {
				  "code": null,
				  "message": "model not found",
				  "param": null,
				  "type": "api_error"
				}
			}`,
490
491
492
493
494
495
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
496
497
498
499
		router := gin.New()
		router.Use(RetrieveMiddleware())
		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
500

501
502
		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)
503

504
505
506
507
508
		var expected, actual map[string]any
		err := json.Unmarshal([]byte(tc.resp), &expected)
		if err != nil {
			t.Fatalf("failed to unmarshal expected response: %v", err)
		}
509

510
511
512
513
514
515
516
517
		err = json.Unmarshal(resp.Body.Bytes(), &actual)
		if err != nil {
			t.Fatalf("failed to unmarshal actual response: %v", err)
		}

		if !reflect.DeepEqual(expected, actual) {
			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
		}
518
519
	}
}