routes_test.go 16.8 KB
Newer Older
1
2
3
package server

import (
4
	"bytes"
5
	"context"
6
	"encoding/binary"
7
8
	"encoding/json"
	"fmt"
9
	"io"
10
	"math"
11
12
	"net/http"
	"net/http/httptest"
13
	"os"
Patrick Devine's avatar
Patrick Devine committed
14
	"sort"
15
	"strings"
16
17
	"testing"

18
	"github.com/ollama/ollama/api"
19
	"github.com/ollama/ollama/llm"
20
	"github.com/ollama/ollama/openai"
21
	"github.com/ollama/ollama/parser"
22
	"github.com/ollama/ollama/types/model"
23
	"github.com/ollama/ollama/version"
24
25
)

26
27
func createTestFile(t *testing.T, name string) string {
	// createTestFile creates a temp file (name is the os.CreateTemp pattern)
	// inside a fresh t.TempDir() and writes a tiny little-endian header to it:
	// the 4 bytes "GGUF", a uint32 3 (presumably the GGUF version), and two
	// uint64 zeros (presumably the tensor and metadata key/value counts —
	// NOTE(review): confirm against the GGUF spec). It returns the file path.
	// Any create or write failure aborts the test via t.Fatalf.
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), name)
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	defer f.Close()

	// Header fields in on-disk order.
	header := []any{
		[]byte("GGUF"),
		uint32(3),
		uint64(0),
		uint64(0),
	}
	for _, field := range header {
		if werr := binary.Write(f, binary.LittleEndian, field); werr != nil {
			t.Fatalf("failed to write to file: %v", werr)
		}
	}

	return f.Name()
}
57

58
59
60
61
62
63
64
65
66
67
68
69
70
// equalStringSlices reports whether a and b have the same length and hold
// the same string at every index. A nil slice and an empty slice are equal.
func equalStringSlices(a, b []string) bool {
	same := len(a) == len(b)
	// Walk element pairs only while everything seen so far has matched.
	for i := 0; same && i < len(a); i++ {
		same = a[i] == b[i]
	}
	return same
}

71
72
73
74
75
76
77
func Test_Routes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
78
79
80
	}

	createTestModel := func(t *testing.T, name string) {
81
82
		t.Helper()

83
84
		fname := createTestFile(t, "ollama-model")

Michael Yang's avatar
Michael Yang committed
85
		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
86
		modelfile, err := parser.ParseFile(r)
87
88
89
		if err != nil {
			t.Fatalf("failed to parse file: %v", err)
		}
90
91
92
		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}
93
		err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
94
95
96
		if err != nil {
			t.Fatalf("failed to create model: %v", err)
		}
97
	}
98
99
100
101
102
103
104
105
106
107

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
108
109
110
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
111
				body, err := io.ReadAll(resp.Body)
112
113
114
115
116
117
118
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version)
				if string(body) != expectedBody {
					t.Errorf("expected body %s, got %s", expectedBody, string(body))
				}
119
120
			},
		},
121
122
123
124
125
126
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
127
128
129
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
130
				body, err := io.ReadAll(resp.Body)
131
132
133
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
134
135
136
137

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
138
139
140
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
141

142
143
144
				if modelList.Models == nil || len(modelList.Models) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Models)
				}
145
146
			},
		},
147
148
149
150
151
152
		{
			Name:   "openai empty list",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
153
154
155
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
156
				body, err := io.ReadAll(resp.Body)
157
158
159
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
160
161
162

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
163
164
165
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
166

167
168
169
				if modelList.Object != "list" || len(modelList.Data) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Data)
				}
170
171
			},
		},
172
173
174
175
176
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
177
				createTestModel(t, "test-model")
178
179
180
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
181
182
183
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
184
				body, err := io.ReadAll(resp.Body)
185
186
187
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
188

189
190
191
				if strings.Contains(string(body), "expires_at") {
					t.Errorf("response body should not contain 'expires_at'")
				}
192

193
194
				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" {
					t.Errorf("expected model 'test-model:latest', got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "Delete Model Handler",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "model-to-delete")

				deleteReq := api.DeleteRequest{
					Name: "model-to-delete",
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}
218

219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// Verify the model was deleted
				_, err := GetModel("model-to-delete")
				if err == nil || !os.IsNotExist(err) {
					t.Errorf("expected model to be deleted, got error %v", err)
				}
			},
		},
		{
			Name:   "Delete Non-existent Model",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				deleteReq := api.DeleteRequest{
					Name: "non-existent-model",
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusNotFound {
					t.Errorf("expected status code 404, got %d", resp.StatusCode)
				}

				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var errorResp map[string]string
				err = json.Unmarshal(body, &errorResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if !strings.Contains(errorResp["error"], "not found") {
					t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"])
				}
267
268
			},
		},
269
270
271
272
273
274
		{
			Name:   "openai list models with tags",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
275
276
277
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
278
				body, err := io.ReadAll(resp.Body)
279
280
281
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
282
283
284

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
285
286
287
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
288

289
290
291
				if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" {
					t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data)
				}
292
293
			},
		},
294
295
296
297
298
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
Michael Yang's avatar
Michael Yang committed
299
				fname := createTestFile(t, "ollama-model")
300
301
302
303

				stream := false
				createReq := api.CreateRequest{
					Name:      "t-bone",
Michael Yang's avatar
Michael Yang committed
304
					Modelfile: fmt.Sprintf("FROM %s", fname),
305
306
307
					Stream:    &stream,
				}
				jsonData, err := json.Marshal(createReq)
308
309
310
				if err != nil {
					t.Fatalf("failed to marshal create request: %v", err)
				}
311
312
313
314
315

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
316
317
318
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
319
				_, err := io.ReadAll(resp.Body)
320
321
322
323
324
325
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				if resp.StatusCode != http.StatusOK { // Updated line
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}
326
327

				model, err := GetModel("t-bone")
328
329
330
331
332
333
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if model.ShortName != "t-bone:latest" {
					t.Errorf("expected model name 't-bone:latest', got %s", model.ShortName)
				}
334
335
336
337
338
339
340
341
342
343
344
345
346
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
347
348
349
				if err != nil {
					t.Fatalf("failed to marshal copy request: %v", err)
				}
350
351
352
353
354

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				model, err := GetModel("beefsteak")
355
356
357
358
359
360
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if model.ShortName != "beefsteak:latest" {
					t.Errorf("expected model name 'beefsteak:latest', got %s", model.ShortName)
				}
361
362
			},
		},
Patrick Devine's avatar
Patrick Devine committed
363
364
365
366
367
368
369
370
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
371
372
373
				if err != nil {
					t.Fatalf("failed to marshal show request: %v", err)
				}
Patrick Devine's avatar
Patrick Devine committed
374
375
376
377
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
378
379
380
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
Patrick Devine's avatar
Patrick Devine committed
381
				body, err := io.ReadAll(resp.Body)
382
383
384
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
Patrick Devine's avatar
Patrick Devine committed
385
386
387

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
388
389
390
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
Patrick Devine's avatar
Patrick Devine committed
391
392
393
394
395
396
397
398
399
400
401
402
403

				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
404
405
406
407
408
409
410
411
412
413
				if !equalStringSlices(params, expectedParams) {
					t.Errorf("expected parameters %v, got %v", expectedParams, params)
				}
				paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
				if !ok {
					t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"])
				}
				if math.Abs(paramCount) > 1e-9 {
					t.Errorf("expected parameter count to be 0, got %f", paramCount)
				}
Patrick Devine's avatar
Patrick Devine committed
414
415
			},
		},
416
417
418
419
420
421
		{
			Name:   "openai retrieve model handler",
			Method: http.MethodGet,
			Path:   "/v1/models/show-model",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
422
423
424
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
425
				body, err := io.ReadAll(resp.Body)
426
427
428
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
429
430
431

				var retrieveResp api.RetrieveModelResponse
				err = json.Unmarshal(body, &retrieveResp)
432
433
434
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
435

436
437
438
				if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" {
					t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp)
				}
439
440
			},
		},
441
442
	}

443
444
	t.Setenv("OLLAMA_MODELS", t.TempDir())

445
	s := &Server{}
446
447
448
449
450
451
	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
Michael Yang's avatar
Michael Yang committed
452
453
454
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
455
456
457
			if err != nil {
				t.Fatalf("failed to create request: %v", err)
			}
Michael Yang's avatar
Michael Yang committed
458
459
460
461
462
463

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
464
465
466
			if err != nil {
				t.Fatalf("failed to do request: %v", err)
			}
Michael Yang's avatar
Michael Yang committed
467
468
469
470
471
472
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
473
474
	}
}
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490

func TestCase(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	cases := []string{
		"mistral",
		"llama3:latest",
		"library/phi3:q4_0",
		"registry.ollama.ai/library/gemma:q5_K_M",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/alice/bob:latest",
	}

	var s Server
	for _, tt := range cases {
		t.Run(tt, func(t *testing.T) {
491
			w := createRequest(t, s.CreateHandler, api.CreateRequest{
492
				Name:      tt,
493
				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
494
495
496
497
498
499
500
501
502
503
504
505
506
				Stream:    &stream,
			})

			if w.Code != http.StatusOK {
				t.Fatalf("expected status 200 got %d", w.Code)
			}

			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
			if err != nil {
				t.Fatal(err)
			}

			t.Run("create", func(t *testing.T) {
507
				w = createRequest(t, s.CreateHandler, api.CreateRequest{
508
					Name:      strings.ToUpper(tt),
509
					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
510
511
512
513
514
515
516
517
518
519
520
521
522
					Stream:    &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("pull", func(t *testing.T) {
523
				w := createRequest(t, s.PullHandler, api.PullRequest{
524
525
526
527
528
529
530
531
532
533
534
535
536
537
					Name:   strings.ToUpper(tt),
					Stream: &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("copy", func(t *testing.T) {
538
				w := createRequest(t, s.CopyHandler, api.CopyRequest{
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
					Source:      tt,
					Destination: strings.ToUpper(tt),
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})
		})
	}
}
554
555
556
557
558
559

func TestShow(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	var s Server

560
	createRequest(t, s.CreateHandler, api.CreateRequest{
561
562
563
564
565
566
567
568
		Name: "show-model",
		Modelfile: fmt.Sprintf(
			"FROM %s\nFROM %s",
			createBinFile(t, llm.KV{"general.architecture": "test"}, nil),
			createBinFile(t, llm.KV{"general.architecture": "clip"}, nil),
		),
	})

569
	w := createRequest(t, s.ShowHandler, api.ShowRequest{
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
		Name: "show-model",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	var resp api.ShowResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if resp.ModelInfo["general.architecture"] != "test" {
		t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"])
	}

	if resp.ProjectorInfo["general.architecture"] != "clip" {
		t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
	}
}
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624

// TestNormalize checks that normalize keeps the vector length, scales every
// non-zero input to unit L2 norm (within 1e-6 on the squared norm), and
// leaves an all-zero input all zero.
func TestNormalize(t *testing.T) {
	type testCase struct {
		input []float32
	}

	testCases := []testCase{
		{input: []float32{1}},
		{input: []float32{0, 1, 2, 3}},
		{input: []float32{0.1, 0.2, 0.3}},
		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
		{input: []float32{0, 0, 0}},
	}

	// sumSquares returns the squared L2 norm of vec, accumulated in float64.
	sumSquares := func(vec []float32) float64 {
		sum := 0.0
		for _, v := range vec {
			sum += float64(v * v)
		}
		return sum
	}

	for _, tc := range testCases {
		t.Run(fmt.Sprintf("%v", tc.input), func(t *testing.T) {
			// Snapshot the input before calling normalize: normalize may
			// modify its argument in place (NOTE(review): not visible here),
			// which would otherwise corrupt the expectation and messages.
			original := append([]float32(nil), tc.input...)
			inputIsZero := sumSquares(original) == 0

			normalized := normalize(tc.input)

			// Without this check, returning nil/empty would pass (norm 0).
			if len(normalized) != len(original) {
				t.Fatalf("normalize(%v) returned %d elements, want %d", original, len(normalized), len(original))
			}

			sum := sumSquares(normalized)
			if inputIsZero {
				if sum != 0 {
					t.Errorf("Vector %v should stay zero, got %v", original, normalized)
				}
				return
			}
			// A non-zero input must not collapse to zero; it must have unit norm.
			if math.Abs(sum-1) > 1e-6 {
				t.Errorf("Vector %v is not normalized: got %v (squared norm %v)", original, normalized, sum)
			}
		})
	}
}