routes_test.go 9.57 KB
Newer Older
1
2
3
package server

import (
4
	"bytes"
5
	"context"
6
	"encoding/binary"
7
8
	"encoding/json"
	"fmt"
9
10
11
	"io"
	"net/http"
	"net/http/httptest"
12
	"os"
Patrick Devine's avatar
Patrick Devine committed
13
	"sort"
14
	"strings"
15
16
17
	"testing"

	"github.com/stretchr/testify/assert"
Michael Yang's avatar
lint  
Michael Yang committed
18
	"github.com/stretchr/testify/require"
19

20
	"github.com/ollama/ollama/api"
21
	"github.com/ollama/ollama/envconfig"
22
	"github.com/ollama/ollama/llm"
23
	"github.com/ollama/ollama/parser"
24
	"github.com/ollama/ollama/types/model"
25
	"github.com/ollama/ollama/version"
26
27
)

28
29
func createTestFile(t *testing.T, name string) string {
	t.Helper()
30

31
	f, err := os.CreateTemp(t.TempDir(), name)
Michael Yang's avatar
Michael Yang committed
32
	require.NoError(t, err)
33
	defer f.Close()
34

35
	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
Michael Yang's avatar
Michael Yang committed
36
	require.NoError(t, err)
37

38
	err = binary.Write(f, binary.LittleEndian, uint32(3))
Michael Yang's avatar
Michael Yang committed
39
	require.NoError(t, err)
40

41
	err = binary.Write(f, binary.LittleEndian, uint64(0))
Michael Yang's avatar
Michael Yang committed
42
	require.NoError(t, err)
43

44
	err = binary.Write(f, binary.LittleEndian, uint64(0))
Michael Yang's avatar
Michael Yang committed
45
	require.NoError(t, err)
46

47
48
	return f.Name()
}
49

50
51
52
53
54
55
56
func Test_Routes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
57
58
59
	}

	createTestModel := func(t *testing.T, name string) {
60
61
		t.Helper()

62
63
		fname := createTestFile(t, "ollama-model")

Michael Yang's avatar
Michael Yang committed
64
		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
65
		modelfile, err := parser.ParseFile(r)
Michael Yang's avatar
lint  
Michael Yang committed
66
		require.NoError(t, err)
67
68
69
		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}
70
		err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
Michael Yang's avatar
lint  
Michael Yang committed
71
		require.NoError(t, err)
72
	}
73
74
75
76
77
78
79
80
81
82

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
83
				assert.Equal(t, "application/json; charset=utf-8", contentType)
84
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
85
				require.NoError(t, err)
86
				assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
87
88
			},
		},
89
90
91
92
93
94
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
95
				assert.Equal(t, "application/json; charset=utf-8", contentType)
96
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
97
				require.NoError(t, err)
98
99
100
101

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
Michael Yang's avatar
lint  
Michael Yang committed
102
				require.NoError(t, err)
103

104
				assert.NotNil(t, modelList.Models)
Michael Yang's avatar
lint  
Michael Yang committed
105
				assert.Empty(t, len(modelList.Models))
106
107
108
109
110
111
112
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
113
				createTestModel(t, "test-model")
114
115
116
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
117
				assert.Equal(t, "application/json; charset=utf-8", contentType)
118
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
119
				require.NoError(t, err)
120

121
122
				assert.NotContains(t, string(body), "expires_at")

123
124
				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
Michael Yang's avatar
lint  
Michael Yang committed
125
				require.NoError(t, err)
126

Michael Yang's avatar
lint  
Michael Yang committed
127
128
				assert.Len(t, modelList.Models, 1)
				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
129
130
131
132
133
134
135
			},
		},
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
Michael Yang's avatar
Michael Yang committed
136
				fname := createTestFile(t, "ollama-model")
137
138
139
140

				stream := false
				createReq := api.CreateRequest{
					Name:      "t-bone",
Michael Yang's avatar
Michael Yang committed
141
					Modelfile: fmt.Sprintf("FROM %s", fname),
142
143
144
					Stream:    &stream,
				}
				jsonData, err := json.Marshal(createReq)
Michael Yang's avatar
lint  
Michael Yang committed
145
				require.NoError(t, err)
146
147
148
149
150
151
152

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, "application/json", contentType)
				_, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
153
154
				require.NoError(t, err)
				assert.Equal(t, 200, resp.StatusCode)
155
156

				model, err := GetModel("t-bone")
Michael Yang's avatar
lint  
Michael Yang committed
157
				require.NoError(t, err)
158
159
160
161
162
163
164
165
166
167
168
169
170
171
				assert.Equal(t, "t-bone:latest", model.ShortName)
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
Michael Yang's avatar
lint  
Michael Yang committed
172
				require.NoError(t, err)
173
174
175
176
177

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				model, err := GetModel("beefsteak")
Michael Yang's avatar
lint  
Michael Yang committed
178
				require.NoError(t, err)
179
				assert.Equal(t, "beefsteak:latest", model.ShortName)
180
181
			},
		},
Patrick Devine's avatar
Patrick Devine committed
182
183
184
185
186
187
188
189
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
Michael Yang's avatar
lint  
Michael Yang committed
190
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
191
192
193
194
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
195
				assert.Equal(t, "application/json; charset=utf-8", contentType)
Patrick Devine's avatar
Patrick Devine committed
196
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
197
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
198
199
200

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
Michael Yang's avatar
lint  
Michael Yang committed
201
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
202
203
204
205
206
207
208
209
210
211
212
213
214
215

				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				assert.Equal(t, expectedParams, params)
216
				assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
Patrick Devine's avatar
Patrick Devine committed
217
218
			},
		},
219
220
	}

221
	t.Setenv("OLLAMA_MODELS", t.TempDir())
222
	envconfig.LoadConfig()
223

224
	s := &Server{}
225
226
227
228
229
230
	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
Michael Yang's avatar
Michael Yang committed
231
232
233
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
Michael Yang's avatar
lint  
Michael Yang committed
234
			require.NoError(t, err)
Michael Yang's avatar
Michael Yang committed
235
236
237
238
239
240

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
Michael Yang's avatar
lint  
Michael Yang committed
241
			require.NoError(t, err)
Michael Yang's avatar
Michael Yang committed
242
243
244
245
246
247
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
248
249
	}
}
250
251
252

func TestCase(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())
253
	envconfig.LoadConfig()
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268

	cases := []string{
		"mistral",
		"llama3:latest",
		"library/phi3:q4_0",
		"registry.ollama.ai/library/gemma:q5_K_M",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/alice/bob:latest",
	}

	var s Server
	for _, tt := range cases {
		t.Run(tt, func(t *testing.T) {
			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
				Name:      tt,
269
				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
				Stream:    &stream,
			})

			if w.Code != http.StatusOK {
				t.Fatalf("expected status 200 got %d", w.Code)
			}

			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
			if err != nil {
				t.Fatal(err)
			}

			t.Run("create", func(t *testing.T) {
				w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
					Name:      strings.ToUpper(tt),
285
					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
					Stream:    &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("pull", func(t *testing.T) {
				w := createRequest(t, s.PullModelHandler, api.PullRequest{
					Name:   strings.ToUpper(tt),
					Stream: &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("copy", func(t *testing.T) {
				w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
					Source:      tt,
					Destination: strings.ToUpper(tt),
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})
		})
	}
}
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366

// TestShow creates a model from two files: one whose general.architecture is
// "test" and one whose is "clip". It then checks that the show handler
// reports the first under ModelInfo and the second under ProjectorInfo.
func TestShow(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())
	envconfig.LoadConfig()

	var s Server

	// Streaming is disabled, as in TestCase, so that a failed create is
	// reflected in the status code. It is checked here so that a fixture
	// failure is not misreported as a show failure below.
	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name: "show-model",
		Modelfile: fmt.Sprintf(
			"FROM %s\nFROM %s",
			createBinFile(t, llm.KV{"general.architecture": "test"}, nil),
			createBinFile(t, llm.KV{"general.architecture": "clip"}, nil),
		),
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected create status code 200, actual %d: %s", w.Code, w.Body.String())
	}

	w = createRequest(t, s.ShowModelHandler, api.ShowRequest{
		Name: "show-model",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	var resp api.ShowResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if resp.ModelInfo["general.architecture"] != "test" {
		t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"])
	}

	if resp.ProjectorInfo["general.architecture"] != "clip" {
		t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
	}
}