routes_test.go 19.7 KB
Newer Older
1
2
3
package server

import (
4
	"bytes"
5
	"context"
6
	"encoding/binary"
7
8
	"encoding/json"
	"fmt"
9
	"io"
10
	"io/fs"
11
	"math"
12
13
	"math/rand/v2"
	"net"
14
15
	"net/http"
	"net/http/httptest"
16
	"os"
17
	"path/filepath"
Patrick Devine's avatar
Patrick Devine committed
18
	"sort"
19
	"strings"
20
	"testing"
21
	"unicode"
22

23
	"github.com/ollama/ollama/api"
Michael Yang's avatar
Michael Yang committed
24
	"github.com/ollama/ollama/fs/ggml"
25
	"github.com/ollama/ollama/openai"
26
	"github.com/ollama/ollama/server/internal/client/ollama"
27
	"github.com/ollama/ollama/types/model"
28
	"github.com/ollama/ollama/version"
29
30
)

31
func createTestFile(t *testing.T, name string) (string, string) {
32
	t.Helper()
33

34
35
36
37
38
	modelDir := os.Getenv("OLLAMA_MODELS")
	if modelDir == "" {
		t.Fatalf("OLLAMA_MODELS not specified")
	}

39
	f, err := os.CreateTemp(t.TempDir(), name)
40
41
42
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
43
	defer f.Close()
44

45
	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
46
47
48
	if err != nil {
		t.Fatalf("failed to write to file: %v", err)
	}
49

50
	err = binary.Write(f, binary.LittleEndian, uint32(3))
51
52
53
	if err != nil {
		t.Fatalf("failed to write to file: %v", err)
	}
54

55
	err = binary.Write(f, binary.LittleEndian, uint64(0))
56
57
58
	if err != nil {
		t.Fatalf("failed to write to file: %v", err)
	}
59

60
	err = binary.Write(f, binary.LittleEndian, uint64(0))
61
62
63
	if err != nil {
		t.Fatalf("failed to write to file: %v", err)
	}
64

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
	// Calculate sha256 sum of file
	if _, err := f.Seek(0, 0); err != nil {
		t.Fatal(err)
	}

	digest, _ := GetSHA256Digest(f)
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	if err := createLink(f.Name(), filepath.Join(modelDir, "blobs", fmt.Sprintf("sha256-%s", strings.TrimPrefix(digest, "sha256:")))); err != nil {
		t.Fatal(err)
	}

	return f.Name(), digest
80
}
81

82
83
84
85
86
87
88
89
90
91
92
93
94
// equalStringSlices reports whether a and b have the same length and the same
// string at every index. A nil slice and an empty slice are considered equal.
func equalStringSlices(a, b []string) bool {
	same := len(a) == len(b)
	for i := 0; same && i < len(a); i++ {
		same = a[i] == b[i]
	}
	return same
}

95
96
97
98
99
100
101
102
103
// panicTransport is an http.RoundTripper that panics on every request, so a
// test using it notices any outbound HTTP call it did not expect.
type panicTransport struct{}

// RoundTrip never returns; it always panics with "unexpected RoundTrip call".
func (*panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

// panicOnRoundTrip is an *http.Client whose every request panics via
// panicTransport.
var panicOnRoundTrip = &http.Client{Transport: new(panicTransport)}

func TestRoutes(t *testing.T) {
104
105
106
107
108
109
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
110
111
112
	}

	createTestModel := func(t *testing.T, name string) {
113
114
		t.Helper()

115
		_, digest := createTestFile(t, "ollama-model")
116
117
118
119

		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}
120
121
122
123
124
125
126
127
128
129
130
131
132
133

		r := api.CreateRequest{
			Name:  name,
			Files: map[string]string{"test.gguf": digest},
			Parameters: map[string]any{
				"seed":  42,
				"top_p": 0.9,
				"stop":  []string{"foo", "bar"},
			},
		}

		modelName := model.ParseName(name)

		baseLayers, err := ggufLayers(digest, fn)
134
135
136
		if err != nil {
			t.Fatalf("failed to create model: %v", err)
		}
137
138
139
140

		if err := createModel(r, modelName, baseLayers, fn); err != nil {
			t.Fatal(err)
		}
141
	}
142
143
144
145
146
147
148
149
150
151

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
152
153
154
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
155
				body, err := io.ReadAll(resp.Body)
156
157
158
159
160
161
162
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version)
				if string(body) != expectedBody {
					t.Errorf("expected body %s, got %s", expectedBody, string(body))
				}
163
164
			},
		},
165
166
167
168
169
170
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
171
172
173
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
174
				body, err := io.ReadAll(resp.Body)
175
176
177
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
178
179
180
181

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
182
183
184
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
185

186
187
188
				if modelList.Models == nil || len(modelList.Models) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Models)
				}
189
190
			},
		},
191
192
193
194
195
196
		{
			Name:   "openai empty list",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
197
198
199
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
200
				body, err := io.ReadAll(resp.Body)
201
202
203
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
204
205
206

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
207
208
209
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
210

211
212
213
				if modelList.Object != "list" || len(modelList.Data) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Data)
				}
214
215
			},
		},
216
217
218
219
220
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
221
				createTestModel(t, "test-model")
222
223
224
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
225
226
227
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
228
				body, err := io.ReadAll(resp.Body)
229
230
231
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
232

233
234
235
				if strings.Contains(string(body), "expires_at") {
					t.Errorf("response body should not contain 'expires_at'")
				}
236

237
238
				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
239
240
241
242
243
244
245
246
247
248
249
250
251
252
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" {
					t.Errorf("expected model 'test-model:latest', got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "Delete Model Handler",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
253
				createTestModel(t, "model_to_delete")
254
255

				deleteReq := api.DeleteRequest{
256
					Name: "model_to_delete",
257
258
259
260
261
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}
262

263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// Verify the model was deleted
				_, err := GetModel("model-to-delete")
				if err == nil || !os.IsNotExist(err) {
					t.Errorf("expected model to be deleted, got error %v", err)
				}
			},
		},
		{
			Name:   "Delete Non-existent Model",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				deleteReq := api.DeleteRequest{
283
					Name: "non_existent_model",
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusNotFound {
					t.Errorf("expected status code 404, got %d", resp.StatusCode)
				}

				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var errorResp map[string]string
				err = json.Unmarshal(body, &errorResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if !strings.Contains(errorResp["error"], "not found") {
					t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"])
				}
311
312
			},
		},
313
314
315
316
317
318
		{
			Name:   "openai list models with tags",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
319
320
321
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
322
				body, err := io.ReadAll(resp.Body)
323
324
325
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
326
327
328

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
329
330
331
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
332

333
334
335
				if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" {
					t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data)
				}
336
337
			},
		},
338
339
340
341
342
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
343
				_, digest := createTestFile(t, "ollama-model")
344
345
				stream := false
				createReq := api.CreateRequest{
346
347
348
					Name:   "t-bone",
					Files:  map[string]string{"test.gguf": digest},
					Stream: &stream,
349
350
				}
				jsonData, err := json.Marshal(createReq)
351
352
353
				if err != nil {
					t.Fatalf("failed to marshal create request: %v", err)
				}
354
355
356
357
358

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
359
360
361
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
362
				_, err := io.ReadAll(resp.Body)
363
364
365
366
367
368
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				if resp.StatusCode != http.StatusOK { // Updated line
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}
369
370

				model, err := GetModel("t-bone")
371
372
373
374
375
376
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if model.ShortName != "t-bone:latest" {
					t.Errorf("expected model name 't-bone:latest', got %s", model.ShortName)
				}
377
378
379
380
381
382
383
384
385
386
387
388
389
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
390
391
392
				if err != nil {
					t.Fatalf("failed to marshal copy request: %v", err)
				}
393
394
395
396
397

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				model, err := GetModel("beefsteak")
398
399
400
401
402
403
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if model.ShortName != "beefsteak:latest" {
					t.Errorf("expected model name 'beefsteak:latest', got %s", model.ShortName)
				}
404
405
			},
		},
Patrick Devine's avatar
Patrick Devine committed
406
407
408
409
410
411
412
413
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
414
415
416
				if err != nil {
					t.Fatalf("failed to marshal show request: %v", err)
				}
Patrick Devine's avatar
Patrick Devine committed
417
418
419
420
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
421
422
423
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
Patrick Devine's avatar
Patrick Devine committed
424
				body, err := io.ReadAll(resp.Body)
425
426
427
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
Patrick Devine's avatar
Patrick Devine committed
428
429
430

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
431
432
433
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
Patrick Devine's avatar
Patrick Devine committed
434
435
436
437
438
439
440
441
442
443
444
445
446

				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
447
448
449
450
451
452
453
454
455
456
				if !equalStringSlices(params, expectedParams) {
					t.Errorf("expected parameters %v, got %v", expectedParams, params)
				}
				paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
				if !ok {
					t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"])
				}
				if math.Abs(paramCount) > 1e-9 {
					t.Errorf("expected parameter count to be 0, got %f", paramCount)
				}
Patrick Devine's avatar
Patrick Devine committed
457
458
			},
		},
459
		{
460
461
462
463
			Name: "openai retrieve model handler",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
			},
464
465
466
467
			Method: http.MethodGet,
			Path:   "/v1/models/show-model",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
468
469
470
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
471
				body, err := io.ReadAll(resp.Body)
472
473
474
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
475
476
477

				var retrieveResp api.RetrieveModelResponse
				err = json.Unmarshal(body, &retrieveResp)
478
479
480
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}
481

482
483
484
				if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" {
					t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp)
				}
485
486
			},
		},
487
488
	}

489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	rc := &ollama.Registry{
		// This is a temporary measure to allow us to move forward,
		// surfacing any code contacting ollama.com we do not intended
		// to.
		//
		// Currently, this only handles DELETE /api/delete, which
		// should not make any contact with the ollama.com registry, so
		// be clear about that.
		//
		// Tests that do need to contact the registry here, will be
		// consumed into our new server/api code packages and removed
		// from here.
		HTTPClient: panicOnRoundTrip,
	}
506

507
	s := &Server{}
508
	router, err := s.GenerateRoutes(rc)
509
510
511
	if err != nil {
		t.Fatalf("failed to generate routes: %v", err)
	}
512
513
514
515
516

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
Michael Yang's avatar
Michael Yang committed
517
518
519
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
520
521
522
			if err != nil {
				t.Fatalf("failed to create request: %v", err)
			}
Michael Yang's avatar
Michael Yang committed
523
524
525
526
527
528

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
529
530
531
			if err != nil {
				t.Fatalf("failed to do request: %v", err)
			}
Michael Yang's avatar
Michael Yang committed
532
533
534
535
536
537
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
538
539
	}
}
540

541
542
543
544
545
546
547
548
func casingShuffle(s string) string {
	rr := []rune(s)
	for i := range rr {
		if rand.N(2) == 0 {
			rr[i] = unicode.ToUpper(rr[i])
		} else {
			rr[i] = unicode.ToLower(rr[i])
		}
549
	}
550
551
	return string(rr)
}
552

553
554
func TestManifestCaseSensitivity(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())
555

556
557
558
559
560
561
562
563
564
565
566
567
568
	r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		io.WriteString(w, `{}`) //nolint:errcheck
	}))
	defer r.Close()

	nameUsed := make(map[string]bool)
	name := func() string {
		const fqmn = "example/namespace/model:tag"
		for {
			v := casingShuffle(fqmn)
			if nameUsed[v] {
				continue
569
			}
570
571
572
573
			nameUsed[v] = true
			return v
		}
	}
574

575
	wantStableName := name()
576

577
578
	t.Logf("stable name: %s", wantStableName)

579
580
581
582
	// checkManifestList tests that there is strictly one manifest in the
	// models directory, and that the manifest is for the model under test.
	checkManifestList := func() {
		t.Helper()
583

584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
		mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
		var entries []string
		t.Logf("dir entries:")
		fsys := os.DirFS(mandir)
		err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
			if err != nil {
				return err
			}
			t.Logf("    %s", fs.FormatDirEntry(info))
			if info.IsDir() {
				return nil
			}
			path = strings.TrimPrefix(path, mandir)
			entries = append(entries, path)
			return nil
		})
		if err != nil {
			t.Fatalf("failed to walk directory: %v", err)
		}
603

604
605
606
607
		if len(entries) != 1 {
			t.Errorf("len(got) = %d, want 1", len(entries))
			return // do not use Fatal so following steps run
		}
608

609
610
611
612
613
614
615
616
		g := entries[0] // raw path
		g = filepath.ToSlash(g)
		w := model.ParseName(wantStableName).Filepath()
		w = filepath.ToSlash(w)
		if g != w {
			t.Errorf("\ngot:  %s\nwant: %s", g, w)
		}
	}
617

618
619
620
621
622
623
624
	checkOK := func(w *httptest.ResponseRecorder) {
		t.Helper()
		if w.Code != http.StatusOK {
			t.Errorf("code = %d, want 200", w.Code)
			t.Logf("body: %s", w.Body.String())
		}
	}
625

626
627
628
629
	var s Server
	testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
630
	}
631
632
633
	t.Cleanup(func() { testMakeRequestDialContext = nil })

	t.Logf("creating")
634
	_, digest := createBinFile(t, nil, nil)
635
636
637
	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
		// Start with the stable name, and later use a case-shuffled
		// version.
638
639
640
		Name:   wantStableName,
		Files:  map[string]string{"test.gguf": digest},
		Stream: &stream,
641
642
643
644
645
	}))
	checkManifestList()

	t.Logf("creating (again)")
	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
646
647
648
		Name:   name(),
		Files:  map[string]string{"test.gguf": digest},
		Stream: &stream,
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
	}))
	checkManifestList()

	t.Logf("pulling")
	checkOK(createRequest(t, s.PullHandler, api.PullRequest{
		Name:     name(),
		Stream:   &stream,
		Insecure: true,
	}))
	checkManifestList()

	t.Logf("copying")
	checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
		Source:      name(),
		Destination: name(),
	}))
	checkManifestList()
666
667
668
669
670
671
672
673
674
675
676
677

	t.Logf("pushing")
	rr := createRequest(t, s.PushHandler, api.PushRequest{
		Model:    name(),
		Insecure: true,
		Username: "alice",
		Password: "x",
	})
	checkOK(rr)
	if !strings.Contains(rr.Body.String(), `"status":"success"`) {
		t.Errorf("got = %q, want success", rr.Body.String())
	}
678
}
679
680
681
682
683
684

func TestShow(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	var s Server

Michael Yang's avatar
Michael Yang committed
685
686
	_, digest1 := createBinFile(t, ggml.KV{"general.architecture": "test"}, nil)
	_, digest2 := createBinFile(t, ggml.KV{"general.type": "projector", "general.architecture": "clip"}, nil)
687

688
	createRequest(t, s.CreateHandler, api.CreateRequest{
689
690
		Name:  "show-model",
		Files: map[string]string{"model.gguf": digest1, "projector.gguf": digest2},
691
692
	})

693
	w := createRequest(t, s.ShowHandler, api.ShowRequest{
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
		Name: "show-model",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	var resp api.ShowResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if resp.ModelInfo["general.architecture"] != "test" {
		t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"])
	}

	if resp.ProjectorInfo["general.architecture"] != "clip" {
		t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
	}
}
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748

// TestNormalize checks that normalize keeps the vector length and returns
// either a unit vector (squared L2 norm within 1e-6 of 1) or an all-zero
// vector (the accepted result for an all-zero input).
func TestNormalize(t *testing.T) {
	type testCase struct {
		input []float32
	}

	testCases := []testCase{
		{input: []float32{1}},
		{input: []float32{0, 1, 2, 3}},
		{input: []float32{0.1, 0.2, 0.3}},
		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
		{input: []float32{0, 0, 0}},
	}

	// isNormalized reports whether vec has squared norm within 1e-6 of 1, or
	// is exactly zero. The comparison is written with <= so that a NaN sum
	// fails; the inverted `> 1e-6` form would have accepted NaN vectors.
	isNormalized := func(vec []float32) bool {
		sum := 0.0
		for _, v := range vec {
			sum += float64(v * v)
		}
		if math.Abs(sum-1) <= 1e-6 {
			return true
		}
		return sum == 0
	}

	for _, tc := range testCases {
		t.Run(fmt.Sprint(tc.input), func(t *testing.T) {
			// Hand normalize a copy: it may scale its argument in place,
			// which would make the failure message print the output
			// instead of the original input.
			input := append([]float32(nil), tc.input...)
			normalized := normalize(input)
			if len(normalized) != len(tc.input) {
				t.Fatalf("Vector %v: normalized length %d, want %d", tc.input, len(normalized), len(tc.input))
			}
			if !isNormalized(normalized) {
				t.Errorf("Vector %v is not normalized", tc.input)
			}
		})
	}
}