openai_test.go 12.1 KB
Newer Older
1
2
3
4
package openai

import (
	"bytes"
5
	"encoding/base64"
6
7
8
9
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
10
	"reflect"
11
	"strings"
12
13
14
15
	"testing"
	"time"

	"github.com/gin-gonic/gin"
Michael Yang's avatar
lint  
Michael Yang committed
16
17

	"github.com/ollama/ollama/api"
18
19
)

Michael Yang's avatar
lint  
Michael Yang committed
20
// Fixtures for the image-content chat test case.
const (
	// prefix is the data-URL header placed in front of the base64 payload
	// when building the "image_url" field of the request body.
	prefix = `data:image/jpeg;base64,`

	// image is the base64-encoded payload that follows prefix in the request
	// and whose decoded bytes are expected in the resulting api.ImageData.
	// NOTE(review): the payload starts with the base64 form of the PNG
	// signature although prefix declares image/jpeg — confirm whether the
	// mismatch is intentional.
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
24

25
26
27
28
// False is an addressable false value; the test tables take its address
// (Stream: &False) where a *bool is needed.
var False = false

// True is the addressable counterpart of False (Stream: &True).
var True = true
29
30
31
32
33
34
35
36
37
38
39
40
41
42

// captureRequestMiddleware returns a gin handler that decodes the current
// request body as JSON into capturedRequest, so a test can inspect the body
// as it looks at this point in the middleware chain.
//
// capturedRequest must be a pointer that json.Unmarshal can populate. The body
// is read in full and then replaced with a fresh reader over the same bytes,
// so later handlers can still consume it. If the body cannot be read or
// decoded, the chain is aborted with a 500 response and no further handler
// runs.
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
	return func(c *gin.Context) {
		bodyBytes, err := io.ReadAll(c.Request.Body)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to read request")
			return
		}

		// Restore the body: ReadAll drained the original reader.
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

		if err := json.Unmarshal(bodyBytes, capturedRequest); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
			return
		}

		c.Next()
	}
}

func TestChatMiddleware(t *testing.T) {
43
	type testCase struct {
44
45
46
47
		name string
		body string
		req  api.ChatRequest
		err  ErrorResponse
48
49
	}

50
	var capturedRequest *api.ChatRequest
51

royjhan's avatar
royjhan committed
52
53
	testCases := []testCase{
		{
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
			name: "chat handler",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
74
			},
75
		},
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
		{
			name: "chat handler with options",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				],
				"stream":            true,
				"max_tokens":        999,
				"seed":              123,
				"stop":              ["\n", "stop"],
				"temperature":       3.0,
				"frequency_penalty": 4.0,
				"presence_penalty":  5.0,
				"top_p":             6.0,
				"response_format":   {"type": "json_object"}
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
					"seed":              123.0,
					"stop":              []any{"\n", "stop"},
					"temperature":       6.0,
					"frequency_penalty": 8.0,
					"presence_penalty":  10.0,
					"top_p":             6.0,
				},
				Format: "json",
				Stream: &True,
			},
		},
114
		{
115
116
117
118
119
120
121
122
123
124
			name: "chat handler with image content",
			body: `{
				"model": "test-model",
				"messages": [
					{
						"role": "user",
						"content": [
							{
								"type": "text",
								"text": "Hello"
125
							},
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
							{
								"type": "image_url",
								"image_url": {
									"url": "` + prefix + image + `"
								}
							}
						]
					}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
					{
						Role: "user",
						Images: []api.ImageData{
							func() []byte {
								img, _ := base64.StdEncoding.DecodeString(image)
								return img
							}(),
150
151
						},
					},
152
153
154
155
156
157
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
158
159
160
			},
		},
		{
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
			name: "chat handler with tools",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]interface{}{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
187
							},
188
						},
189
					},
190
191
192
193
194
195
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
196
197
			},
		},
198

199
200
201
202
203
204
205
206
207
208
209
210
211
		{
			name: "chat handler error forwarding",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": 2}
				]
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid message content type: float64",
					Type:    "invalid_request_error",
				},
212
213
214
215
216
217
218
219
220
221
222
223
224
225
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/chat", endpoint)

	for _, tc := range testCases {
226
227
228
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")
229

230
231
			defer func() { capturedRequest = nil }()

232
233
234
			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

235
236
237
238
239
240
241
242
243
			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}
			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}
244

245
246
247
			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}
248
249
250
251
252
253
		})
	}
}

func TestCompletionsMiddleware(t *testing.T) {
	type testCase struct {
254
255
256
257
		name string
		body string
		req  api.GenerateRequest
		err  ErrorResponse
258
259
260
261
262
263
	}

	var capturedRequest *api.GenerateRequest

	testCases := []testCase{
		{
264
265
266
267
268
269
270
271
272
273
274
275
276
277
			name: "completions handler",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
278
					"temperature":       0.8,
279
280
281
282
283
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &False,
284
285
			},
		},
286
		{
287
288
289
290
291
292
293
294
295
296
297
298
299
			name: "completions handler error forwarding",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": null,
				"stop": [1, 2],
				"suffix": "suffix"
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid type for 'stop' field: float64",
					Type:    "invalid_request_error",
				},
300
301
302
			},
		},
	}
303

304
305
306
	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}
307

308
309
310
311
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/generate", endpoint)
312

313
	for _, tc := range testCases {
314
315
316
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")
317
318
319
320

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

321
322
323
324
325
326
327
328
329
330
331
332
333
334
			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}
335
336
337
338
339
340
341
342

			capturedRequest = nil
		})
	}
}

func TestEmbeddingsMiddleware(t *testing.T) {
	type testCase struct {
343
344
345
346
		name string
		body string
		req  api.EmbedRequest
		err  ErrorResponse
347
348
349
350
351
	}

	var capturedRequest *api.EmbedRequest

	testCases := []testCase{
352
		{
353
354
355
356
357
358
359
360
			name: "embed handler single input",
			body: `{
				"input": "Hello",
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: "Hello",
				Model: "test-model",
361
362
363
			},
		},
		{
364
365
366
367
368
369
370
371
			name: "embed handler batch input",
			body: `{
				"input": ["Hello", "World"],
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: []any{"Hello", "World"},
				Model: "test-model",
372
373
			},
		},
374
		{
375
376
377
378
379
380
381
382
383
			name: "embed handler error forwarding",
			body: `{
				"model": "test-model"
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid input",
					Type:    "invalid_request_error",
				},
384
385
386
			},
		},
	}
387

royjhan's avatar
royjhan committed
388
389
390
	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}
391

392
393
394
395
396
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/embed", endpoint)

royjhan's avatar
royjhan committed
397
	for _, tc := range testCases {
398
399
400
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")
401

royjhan's avatar
royjhan committed
402
403
			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)
404

405
406
407
408
409
410
411
412
413
414
415
416
417
418
			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}
419
420

			capturedRequest = nil
royjhan's avatar
royjhan committed
421
422
423
		})
	}
}
424

425
func TestListMiddleware(t *testing.T) {
royjhan's avatar
royjhan committed
426
	type testCase struct {
427
428
429
		name     string
		endpoint func(c *gin.Context)
		resp     string
royjhan's avatar
royjhan committed
430
431
432
	}

	testCases := []testCase{
433
		{
434
435
			name: "list handler",
			endpoint: func(c *gin.Context) {
436
437
438
				c.JSON(http.StatusOK, api.ListResponse{
					Models: []api.ListModelResponse{
						{
439
440
							Name:       "test-model",
							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
441
442
443
444
						},
					},
				})
			},
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
			resp: `{
				"object": "list",
				"data": [
					{
						"id": "test-model",
						"object": "model",
						"created": 1686935002,
						"owned_by": "library"
					}
				]
			}`,
		},
		{
			name: "list handler empty output",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{})
			},
			resp: `{
				"object": "list",
				"data": null
			}`,
		},
	}
468

469
	gin.SetMode(gin.TestMode)
470

471
472
473
474
475
	for _, tc := range testCases {
		router := gin.New()
		router.Use(ListMiddleware())
		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
476

477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)

		var expected, actual map[string]any
		err := json.Unmarshal([]byte(tc.resp), &expected)
		if err != nil {
			t.Fatalf("failed to unmarshal expected response: %v", err)
		}

		err = json.Unmarshal(resp.Body.Bytes(), &actual)
		if err != nil {
			t.Fatalf("failed to unmarshal actual response: %v", err)
		}

		if !reflect.DeepEqual(expected, actual) {
			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
		}
	}
}

func TestRetrieveMiddleware(t *testing.T) {
	type testCase struct {
		name     string
		endpoint func(c *gin.Context)
		resp     string
	}

	testCases := []testCase{
505
		{
506
507
			name: "retrieve handler",
			endpoint: func(c *gin.Context) {
508
				c.JSON(http.StatusOK, api.ShowResponse{
509
					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
510
511
				})
			},
512
513
514
515
516
517
518
519
520
521
522
			resp: `{
				"id":"test-model",
				"object":"model",
				"created":1686935002,
				"owned_by":"library"}
			`,
		},
		{
			name: "retrieve handler error forwarding",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
523
			},
524
525
526
527
528
529
530
531
			resp: `{
				"error": {
				  "code": null,
				  "message": "model not found",
				  "param": null,
				  "type": "api_error"
				}
			}`,
532
533
534
535
536
537
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
538
539
540
541
		router := gin.New()
		router.Use(RetrieveMiddleware())
		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
542

543
544
		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)
545

546
547
548
549
550
		var expected, actual map[string]any
		err := json.Unmarshal([]byte(tc.resp), &expected)
		if err != nil {
			t.Fatalf("failed to unmarshal expected response: %v", err)
		}
551

552
553
554
555
556
557
558
559
		err = json.Unmarshal(resp.Body.Bytes(), &actual)
		if err != nil {
			t.Fatalf("failed to unmarshal actual response: %v", err)
		}

		if !reflect.DeepEqual(expected, actual) {
			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
		}
560
561
	}
}