"official/legacy/transformer/transformer.py" did not exist on "513fdbb258b75755ccbaadae128350052fa69f72"
routes_test.go 17.9 KB
Newer Older
mashun1's avatar
v1  
mashun1 committed
1
2
3
4
5
6
7
8
9
package server

import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
xuxzh1's avatar
update  
xuxzh1 committed
10
	"io/fs"
xuxzh1's avatar
init  
xuxzh1 committed
11
	"math"
xuxzh1's avatar
update  
xuxzh1 committed
12
13
	"math/rand/v2"
	"net"
mashun1's avatar
v1  
mashun1 committed
14
15
16
	"net/http"
	"net/http/httptest"
	"os"
xuxzh1's avatar
update  
xuxzh1 committed
17
	"path/filepath"
mashun1's avatar
v1  
mashun1 committed
18
19
20
	"sort"
	"strings"
	"testing"
xuxzh1's avatar
update  
xuxzh1 committed
21
	"unicode"
mashun1's avatar
v1  
mashun1 committed
22
23

	"github.com/ollama/ollama/api"
xuxzh1's avatar
init  
xuxzh1 committed
24
25
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/openai"
mashun1's avatar
v1  
mashun1 committed
26
	"github.com/ollama/ollama/parser"
xuxzh1's avatar
init  
xuxzh1 committed
27
	"github.com/ollama/ollama/types/model"
mashun1's avatar
v1  
mashun1 committed
28
29
30
31
32
33
34
	"github.com/ollama/ollama/version"
)

func createTestFile(t *testing.T, name string) string {
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), name)
xuxzh1's avatar
update  
xuxzh1 committed
35
36
37
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
mashun1's avatar
v1  
mashun1 committed
38
39
40
	defer f.Close()

	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
xuxzh1's avatar
update  
xuxzh1 committed
41
42
43
	if err != nil {
		t.Fatalf("failed to write to file: %v", err)
	}
mashun1's avatar
v1  
mashun1 committed
44
45

	err = binary.Write(f, binary.LittleEndian, uint32(3))
xuxzh1's avatar
update  
xuxzh1 committed
46
47
48
	if err != nil {
		t.Fatalf("failed to write to file: %v", err)
	}
mashun1's avatar
v1  
mashun1 committed
49
50

	err = binary.Write(f, binary.LittleEndian, uint64(0))
xuxzh1's avatar
update  
xuxzh1 committed
51
52
53
	if err != nil {
		t.Fatalf("failed to write to file: %v", err)
	}
mashun1's avatar
v1  
mashun1 committed
54
55

	err = binary.Write(f, binary.LittleEndian, uint64(0))
xuxzh1's avatar
update  
xuxzh1 committed
56
57
58
	if err != nil {
		t.Fatalf("failed to write to file: %v", err)
	}
mashun1's avatar
v1  
mashun1 committed
59
60
61
62

	return f.Name()
}

xuxzh1's avatar
update  
xuxzh1 committed
63
64
65
66
67
68
69
70
71
72
73
74
75
// equalStringSlices reports whether a and b hold the same strings in the same
// order. A nil slice and an empty slice compare equal.
func equalStringSlices(a, b []string) bool {
	if len(a) == len(b) {
		for idx, s := range a {
			if s != b[idx] {
				return false
			}
		}
		return true
	}
	return false
}

mashun1's avatar
v1  
mashun1 committed
76
77
78
79
80
81
82
83
84
85
// Test_Routes exercises the server's HTTP routes end to end through an
// httptest.Server wrapping s.GenerateRoutes().
//
// All subtests share one OLLAMA_MODELS directory and run sequentially in the
// order listed, so later cases depend on state left behind by earlier ones:
// "openai list models with tags" expects the model created by
// "Tags Handler (yes tags)", and "openai retrieve model handler" expects the
// model created by "Show Model Handler".
func Test_Routes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
	}

	// createTestModel creates a model called name from a minimal GGUF file,
	// with a fixed parameter set: seed 42, top_p 0.9, stop "foo" and "bar".
	createTestModel := func(t *testing.T, name string) {
		t.Helper()

		fname := createTestFile(t, "ollama-model")

		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
		modelfile, err := parser.ParseFile(r)
		if err != nil {
			t.Fatalf("failed to parse file: %v", err)
		}
		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}
		err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
		if err != nil {
			t.Fatalf("failed to create model: %v", err)
		}
	}

	// setJSONBody marshals v and installs it as req's body. what names the
	// request kind in the failure message (e.g. "delete").
	setJSONBody := func(t *testing.T, req *http.Request, what string, v any) {
		t.Helper()
		jsonData, err := json.Marshal(v)
		if err != nil {
			t.Fatalf("failed to marshal %s request: %v", what, err)
		}
		req.Body = io.NopCloser(bytes.NewReader(jsonData))
	}

	// checkContentType records an error if resp's Content-Type is not want.
	checkContentType := func(t *testing.T, resp *http.Response, want string) {
		t.Helper()
		if got := resp.Header.Get("Content-Type"); got != want {
			t.Errorf("expected content type %s, got %s", want, got)
		}
	}

	// readBody returns the full response body, aborting the test on error.
	readBody := func(t *testing.T, resp *http.Response) []byte {
		t.Helper()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("failed to read response body: %v", err)
		}
		return body
	}

	// unmarshalBody decodes body as JSON into v, aborting the test on error.
	unmarshalBody := func(t *testing.T, body []byte, v any) {
		t.Helper()
		if err := json.Unmarshal(body, v); err != nil {
			t.Fatalf("failed to unmarshal response body: %v", err)
		}
	}

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Expected: func(t *testing.T, resp *http.Response) {
				checkContentType(t, resp, "application/json; charset=utf-8")
				body := readBody(t, resp)
				expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version)
				if string(body) != expectedBody {
					t.Errorf("expected body %s, got %s", expectedBody, string(body))
				}
			},
		},
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				checkContentType(t, resp, "application/json; charset=utf-8")
				body := readBody(t, resp)

				var modelList api.ListResponse
				unmarshalBody(t, body, &modelList)

				// The handler must emit an empty JSON array, not null.
				if modelList.Models == nil || len(modelList.Models) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "openai empty list",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				checkContentType(t, resp, "application/json")
				body := readBody(t, resp)

				var modelList openai.ListCompletion
				unmarshalBody(t, body, &modelList)

				if modelList.Object != "list" || len(modelList.Data) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Data)
				}
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "test-model")
			},
			Expected: func(t *testing.T, resp *http.Response) {
				checkContentType(t, resp, "application/json; charset=utf-8")
				body := readBody(t, resp)

				if strings.Contains(string(body), "expires_at") {
					t.Errorf("response body should not contain 'expires_at'")
				}

				var modelList api.ListResponse
				unmarshalBody(t, body, &modelList)

				if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" {
					t.Errorf("expected model 'test-model:latest', got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "Delete Model Handler",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "model-to-delete")
				setJSONBody(t, req, "delete", api.DeleteRequest{
					Name: "model-to-delete",
				})
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// Verify the model was deleted. os.IsNotExist(nil) is false,
				// so a nil error (model still present) is also reported.
				_, err := GetModel("model-to-delete")
				if !os.IsNotExist(err) {
					t.Errorf("expected model to be deleted, got error %v", err)
				}
			},
		},
		{
			Name:   "Delete Non-existent Model",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				setJSONBody(t, req, "delete", api.DeleteRequest{
					Name: "non-existent-model",
				})
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusNotFound {
					t.Errorf("expected status code 404, got %d", resp.StatusCode)
				}

				body := readBody(t, resp)

				var errorResp map[string]string
				unmarshalBody(t, body, &errorResp)

				if !strings.Contains(errorResp["error"], "not found") {
					t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"])
				}
			},
		},
		{
			Name:   "openai list models with tags",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				checkContentType(t, resp, "application/json")
				body := readBody(t, resp)

				var modelList openai.ListCompletion
				unmarshalBody(t, body, &modelList)

				if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" {
					t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data)
				}
			},
		},
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
				fname := createTestFile(t, "ollama-model")

				stream := false
				setJSONBody(t, req, "create", api.CreateRequest{
					Name:      "t-bone",
					Modelfile: fmt.Sprintf("FROM %s", fname),
					Stream:    &stream,
				})
			},
			Expected: func(t *testing.T, resp *http.Response) {
				checkContentType(t, resp, "application/json")
				// Drain the body so read errors surface; its content is not checked.
				readBody(t, resp)
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// Named m to avoid shadowing the imported model package.
				m, err := GetModel("t-bone")
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if m.ShortName != "t-bone:latest" {
					t.Errorf("expected model name 't-bone:latest', got %s", m.ShortName)
				}
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				setJSONBody(t, req, "copy", api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				})
			},
			Expected: func(t *testing.T, resp *http.Response) {
				m, err := GetModel("beefsteak")
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if m.ShortName != "beefsteak:latest" {
					t.Errorf("expected model name 'beefsteak:latest', got %s", m.ShortName)
				}
			},
		},
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				setJSONBody(t, req, "show", api.ShowRequest{Model: "show-model"})
			},
			Expected: func(t *testing.T, resp *http.Response) {
				checkContentType(t, resp, "application/json; charset=utf-8")
				body := readBody(t, resp)

				var showResp api.ShowResponse
				unmarshalBody(t, body, &showResp)

				// Collapse internal whitespace on each line so the comparison
				// does not depend on the server's column alignment.
				var params []string
				for _, p := range strings.Split(showResp.Parameters, "\n") {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				if !equalStringSlices(params, expectedParams) {
					t.Errorf("expected parameters %v, got %v", expectedParams, params)
				}

				paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
				if !ok {
					t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"])
				}
				if math.Abs(paramCount) > 1e-9 {
					t.Errorf("expected parameter count to be 0, got %f", paramCount)
				}
			},
		},
		{
			Name:   "openai retrieve model handler",
			Method: http.MethodGet,
			Path:   "/v1/models/show-model",
			Expected: func(t *testing.T, resp *http.Response) {
				checkContentType(t, resp, "application/json")
				body := readBody(t, resp)

				var retrieveResp api.RetrieveModelResponse
				unmarshalBody(t, body, &retrieveResp)

				if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" {
					t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp)
				}
			},
		},
	}

	t.Setenv("OLLAMA_MODELS", t.TempDir())

	s := &Server{}
	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
			if err != nil {
				t.Fatalf("failed to create request: %v", err)
			}

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
			if err != nil {
				t.Fatalf("failed to do request: %v", err)
			}
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
	}
}

xuxzh1's avatar
update  
xuxzh1 committed
481
482
483
484
485
486
487
488
func casingShuffle(s string) string {
	rr := []rune(s)
	for i := range rr {
		if rand.N(2) == 0 {
			rr[i] = unicode.ToUpper(rr[i])
		} else {
			rr[i] = unicode.ToLower(rr[i])
		}
mashun1's avatar
v1  
mashun1 committed
489
	}
xuxzh1's avatar
update  
xuxzh1 committed
490
491
	return string(rr)
}
mashun1's avatar
v1  
mashun1 committed
492

xuxzh1's avatar
update  
xuxzh1 committed
493
494
// TestManifestCaseSensitivity creates, re-creates, pulls and copies a model
// using case-shuffled spellings of the same fully-qualified name. After every
// step it asserts that exactly one manifest exists on disk, and that the
// manifest sits at the path derived from the first (stable) spelling.
func TestManifestCaseSensitivity(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	// Fake registry: every request gets 200 with an empty JSON object.
	registry := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		io.WriteString(w, `{}`) //nolint:errcheck
	}))
	defer registry.Close()

	// name returns a casing variant of fqmn that it has not returned before.
	nameUsed := make(map[string]bool)
	name := func() string {
		const fqmn = "example/namespace/model:tag"
		for {
			v := casingShuffle(fqmn)
			if nameUsed[v] {
				continue
			}
			nameUsed[v] = true
			return v
		}
	}

	wantStableName := name()

	// checkManifestList tests that there is strictly one manifest in the
	// models directory, and that the manifest is for the model under test.
	checkManifestList := func() {
		t.Helper()

		mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests")
		var entries []string
		t.Logf("dir entries:")
		fsys := os.DirFS(mandir)
		// Paths yielded by WalkDir over fsys are already relative to mandir.
		err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
			if err != nil {
				return err
			}
			t.Logf("    %s", fs.FormatDirEntry(d))
			if d.IsDir() {
				return nil
			}
			entries = append(entries, path)
			return nil
		})
		if err != nil {
			t.Fatalf("failed to walk directory: %v", err)
		}

		if len(entries) != 1 {
			t.Errorf("len(got) = %d, want 1", len(entries))
			return // do not use Fatal so following steps run
		}

		g := filepath.ToSlash(entries[0]) // raw path
		w := filepath.ToSlash(model.ParseName(wantStableName).Filepath())
		if g != w {
			t.Errorf("\ngot:  %s\nwant: %s", g, w)
		}
	}

	checkOK := func(w *httptest.ResponseRecorder) {
		t.Helper()
		if w.Code != http.StatusOK {
			t.Errorf("code = %d, want 200", w.Code)
			t.Logf("body: %s", w.Body.String())
		}
	}

	var s Server
	// Route every outgoing registry connection to the fake registry,
	// regardless of the host named in the model reference.
	testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, "tcp", registry.Listener.Addr().String())
	}
	t.Cleanup(func() { testMakeRequestDialContext = nil })

	t.Logf("creating")
	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
		// Start with the stable name, and later use a case-shuffled
		// version.
		Name: wantStableName,

		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
		Stream:    &stream,
	}))
	checkManifestList()

	t.Logf("creating (again)")
	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
		Name:      name(),
		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
		Stream:    &stream,
	}))
	checkManifestList()

	t.Logf("pulling")
	checkOK(createRequest(t, s.PullHandler, api.PullRequest{
		Name:     name(),
		Stream:   &stream,
		Insecure: true,
	}))
	checkManifestList()

	t.Logf("copying")
	checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
		Source:      name(),
		Destination: name(),
	}))
	checkManifestList()
}
xuxzh1's avatar
init  
xuxzh1 committed
605
606
607
608
609
610

// TestShow creates a model from two GGUF files, one with
// general.architecture "test" and one with general.type "projector" and
// architecture "clip". It then verifies that the show handler reports each
// architecture in ModelInfo and ProjectorInfo respectively.
func TestShow(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	var s Server

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Name: "show-model",
		Modelfile: fmt.Sprintf(
			"FROM %s\nFROM %s",
			createBinFile(t, llm.KV{"general.architecture": "test"}, nil),
			createBinFile(t, llm.KV{"general.type": "projector", "general.architecture": "clip"}, nil),
		),
	})
	// Fail fast with a clear message instead of letting a create failure
	// surface later as a confusing show failure.
	if w.Code != http.StatusOK {
		t.Fatalf("create: expected status code 200, actual %d", w.Code)
	}

	w = createRequest(t, s.ShowHandler, api.ShowRequest{
		Name: "show-model",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	var resp api.ShowResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if got := resp.ModelInfo["general.architecture"]; got != "test" {
		t.Fatalf("Expected model architecture to be 'test', but got %v", got)
	}

	if got := resp.ProjectorInfo["general.architecture"]; got != "clip" {
		t.Fatalf("Expected projector architecture to be 'clip', but got %v", got)
	}
}

// TestNormalize checks that normalize yields either a unit-length vector
// (sum of squares within 1e-6 of 1) or an all-zero vector. The all-zero case
// is allowed because a zero input cannot be scaled to unit length.
func TestNormalize(t *testing.T) {
	type testCase struct {
		input []float32
	}

	testCases := []testCase{
		{input: []float32{1}},
		{input: []float32{0, 1, 2, 3}},
		{input: []float32{0.1, 0.2, 0.3}},
		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
		{input: []float32{0, 0, 0}},
	}

	// isNormalized reports whether vec has unit L2 norm or is all zeros.
	// Written as a positive condition so that a NaN sum (e.g. from a 0/0
	// division in normalize) is rejected; the previous "> tolerance" form
	// evaluated false for NaN and wrongly accepted it.
	isNormalized := func(vec []float32) bool {
		var sum float64
		for _, v := range vec {
			// Square in float64 to avoid losing precision in float32.
			sum += float64(v) * float64(v)
		}
		return sum == 0 || math.Abs(sum-1) <= 1e-6
	}

	for _, tc := range testCases {
		// Compute the name before running; normalize may modify its input.
		name := fmt.Sprint(tc.input)
		t.Run(name, func(t *testing.T) {
			normalized := normalize(tc.input)
			if !isNormalized(normalized) {
				t.Errorf("Vector %v is not normalized", tc.input)
			}
		})
	}
}