routes_test.go 8.29 KB
Newer Older
1
2
3
package server

import (
4
	"bytes"
5
	"context"
6
	"encoding/binary"
7
8
	"encoding/json"
	"fmt"
9
10
11
	"io"
	"net/http"
	"net/http/httptest"
12
	"os"
Patrick Devine's avatar
Patrick Devine committed
13
	"sort"
14
	"strings"
15
16
17
	"testing"

	"github.com/stretchr/testify/assert"
Michael Yang's avatar
lint  
Michael Yang committed
18
	"github.com/stretchr/testify/require"
19

20
	"github.com/ollama/ollama/api"
21
	"github.com/ollama/ollama/parser"
22
	"github.com/ollama/ollama/types/model"
23
	"github.com/ollama/ollama/version"
24
25
)

26
27
func createTestFile(t *testing.T, name string) string {
	t.Helper()
28

29
	f, err := os.CreateTemp(t.TempDir(), name)
Michael Yang's avatar
Michael Yang committed
30
	require.NoError(t, err)
31
	defer f.Close()
32

33
	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
Michael Yang's avatar
Michael Yang committed
34
	require.NoError(t, err)
35

36
	err = binary.Write(f, binary.LittleEndian, uint32(3))
Michael Yang's avatar
Michael Yang committed
37
	require.NoError(t, err)
38

39
	err = binary.Write(f, binary.LittleEndian, uint64(0))
Michael Yang's avatar
Michael Yang committed
40
	require.NoError(t, err)
41

42
	err = binary.Write(f, binary.LittleEndian, uint64(0))
Michael Yang's avatar
Michael Yang committed
43
	require.NoError(t, err)
44

45
46
	return f.Name()
}
47

48
49
50
51
52
53
54
func Test_Routes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
55
56
57
	}

	createTestModel := func(t *testing.T, name string) {
58
59
		t.Helper()

60
61
		fname := createTestFile(t, "ollama-model")

Michael Yang's avatar
Michael Yang committed
62
		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
63
		modelfile, err := parser.ParseFile(r)
Michael Yang's avatar
lint  
Michael Yang committed
64
		require.NoError(t, err)
65
66
67
		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}
68
		err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
Michael Yang's avatar
lint  
Michael Yang committed
69
		require.NoError(t, err)
70
	}
71
72
73
74
75
76
77
78
79
80

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
81
				assert.Equal(t, "application/json; charset=utf-8", contentType)
82
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
83
				require.NoError(t, err)
84
				assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
85
86
			},
		},
87
88
89
90
91
92
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
93
				assert.Equal(t, "application/json; charset=utf-8", contentType)
94
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
95
				require.NoError(t, err)
96
97
98
99

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
Michael Yang's avatar
lint  
Michael Yang committed
100
				require.NoError(t, err)
101

102
				assert.NotNil(t, modelList.Models)
Michael Yang's avatar
lint  
Michael Yang committed
103
				assert.Empty(t, len(modelList.Models))
104
105
106
107
108
109
110
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
111
				createTestModel(t, "test-model")
112
113
114
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
115
				assert.Equal(t, "application/json; charset=utf-8", contentType)
116
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
117
				require.NoError(t, err)
118
119
120

				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
Michael Yang's avatar
lint  
Michael Yang committed
121
				require.NoError(t, err)
122

Michael Yang's avatar
lint  
Michael Yang committed
123
124
				assert.Len(t, modelList.Models, 1)
				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
125
126
127
128
129
130
131
			},
		},
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
Michael Yang's avatar
Michael Yang committed
132
				fname := createTestFile(t, "ollama-model")
133
134
135
136

				stream := false
				createReq := api.CreateRequest{
					Name:      "t-bone",
Michael Yang's avatar
Michael Yang committed
137
					Modelfile: fmt.Sprintf("FROM %s", fname),
138
139
140
					Stream:    &stream,
				}
				jsonData, err := json.Marshal(createReq)
Michael Yang's avatar
lint  
Michael Yang committed
141
				require.NoError(t, err)
142
143
144
145
146
147
148

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, "application/json", contentType)
				_, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
149
150
				require.NoError(t, err)
				assert.Equal(t, 200, resp.StatusCode)
151
152

				model, err := GetModel("t-bone")
Michael Yang's avatar
lint  
Michael Yang committed
153
				require.NoError(t, err)
154
155
156
157
158
159
160
161
162
163
164
165
166
167
				assert.Equal(t, "t-bone:latest", model.ShortName)
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
Michael Yang's avatar
lint  
Michael Yang committed
168
				require.NoError(t, err)
169
170
171
172
173

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				model, err := GetModel("beefsteak")
Michael Yang's avatar
lint  
Michael Yang committed
174
				require.NoError(t, err)
175
				assert.Equal(t, "beefsteak:latest", model.ShortName)
176
177
			},
		},
Patrick Devine's avatar
Patrick Devine committed
178
179
180
181
182
183
184
185
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
Michael Yang's avatar
lint  
Michael Yang committed
186
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
187
188
189
190
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
191
				assert.Equal(t, "application/json; charset=utf-8", contentType)
Patrick Devine's avatar
Patrick Devine committed
192
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
193
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
194
195
196

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
Michael Yang's avatar
lint  
Michael Yang committed
197
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213

				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				assert.Equal(t, expectedParams, params)
			},
		},
214
215
	}

216
217
	t.Setenv("OLLAMA_MODELS", t.TempDir())

218
	s := &Server{}
219
220
221
222
223
224
	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
Michael Yang's avatar
Michael Yang committed
225
226
227
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
Michael Yang's avatar
lint  
Michael Yang committed
228
			require.NoError(t, err)
Michael Yang's avatar
Michael Yang committed
229
230
231
232
233
234

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
Michael Yang's avatar
lint  
Michael Yang committed
235
			require.NoError(t, err)
Michael Yang's avatar
Michael Yang committed
236
237
238
239
240
241
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
242
243
	}
}
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322

// TestCase checks that model names conflict regardless of letter case. For
// each name it first creates the model, then expects creating, pulling, or
// copying to the upper-cased variant of that name to be rejected with
// 400 Bad Request and the "a model with that name already exists" error body.
func TestCase(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	cases := []string{
		"mistral",
		"llama3:latest",
		"library/phi3:q4_0",
		"registry.ollama.ai/library/gemma:q5_K_M",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/alice/bob:latest",
	}

	var s Server
	for _, tt := range cases {
		t.Run(tt, func(t *testing.T) {
			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
				Name:      tt,
				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
				Stream:    &stream,
			})

			if w.Code != http.StatusOK {
				t.Fatalf("expected status %d got %d", http.StatusOK, w.Code)
			}

			// Exact JSON body every conflicting request below must return.
			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
			if err != nil {
				t.Fatal(err)
			}

			t.Run("create", func(t *testing.T) {
				// Local w, like the sibling subtests, so the parent's
				// recorder is not overwritten.
				w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
					Name:      strings.ToUpper(tt),
					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
					Stream:    &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status %d got %d", http.StatusBadRequest, w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("pull", func(t *testing.T) {
				w := createRequest(t, s.PullModelHandler, api.PullRequest{
					Name:   strings.ToUpper(tt),
					Stream: &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status %d got %d", http.StatusBadRequest, w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("copy", func(t *testing.T) {
				w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
					Source:      tt,
					Destination: strings.ToUpper(tt),
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status %d got %d", http.StatusBadRequest, w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})
		})
	}
}