routes_test.go 8.22 KB
Newer Older
1
2
3
package server

import (
4
	"bytes"
5
	"context"
6
	"encoding/binary"
7
8
	"encoding/json"
	"fmt"
9
10
11
	"io"
	"net/http"
	"net/http/httptest"
12
	"os"
Patrick Devine's avatar
Patrick Devine committed
13
	"sort"
14
	"strings"
15
16
17
	"testing"

	"github.com/stretchr/testify/assert"
Michael Yang's avatar
lint  
Michael Yang committed
18
	"github.com/stretchr/testify/require"
19

20
	"github.com/ollama/ollama/api"
21
	"github.com/ollama/ollama/parser"
22
	"github.com/ollama/ollama/version"
23
24
)

25
26
func createTestFile(t *testing.T, name string) string {
	t.Helper()
27

28
	f, err := os.CreateTemp(t.TempDir(), name)
Michael Yang's avatar
lint  
Michael Yang committed
29
	assert.NoError(t, err)
30
	defer f.Close()
31

32
	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
Michael Yang's avatar
lint  
Michael Yang committed
33
	assert.NoError(t, err)
34

35
	err = binary.Write(f, binary.LittleEndian, uint32(3))
Michael Yang's avatar
lint  
Michael Yang committed
36
	assert.NoError(t, err)
37

38
	err = binary.Write(f, binary.LittleEndian, uint64(0))
Michael Yang's avatar
lint  
Michael Yang committed
39
	assert.NoError(t, err)
40

41
	err = binary.Write(f, binary.LittleEndian, uint64(0))
Michael Yang's avatar
lint  
Michael Yang committed
42
	assert.NoError(t, err)
43

44
45
	return f.Name()
}
46

47
48
49
50
51
52
53
func Test_Routes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
54
55
56
57
58
	}

	createTestModel := func(t *testing.T, name string) {
		fname := createTestFile(t, "ollama-model")

Michael Yang's avatar
Michael Yang committed
59
		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
60
		modelfile, err := parser.ParseFile(r)
Michael Yang's avatar
lint  
Michael Yang committed
61
		require.NoError(t, err)
62
63
64
		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}
Michael Yang's avatar
Michael Yang committed
65
		err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
Michael Yang's avatar
lint  
Michael Yang committed
66
		require.NoError(t, err)
67
	}
68
69
70
71
72
73
74
75
76
77

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
78
				assert.Equal(t, "application/json; charset=utf-8", contentType)
79
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
80
				require.NoError(t, err)
81
				assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
82
83
			},
		},
84
85
86
87
88
89
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
90
				assert.Equal(t, "application/json; charset=utf-8", contentType)
91
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
92
				require.NoError(t, err)
93
94
95
96

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
Michael Yang's avatar
lint  
Michael Yang committed
97
				require.NoError(t, err)
98

99
				assert.NotNil(t, modelList.Models)
Michael Yang's avatar
lint  
Michael Yang committed
100
				assert.Empty(t, len(modelList.Models))
101
102
103
104
105
106
107
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
108
				createTestModel(t, "test-model")
109
110
111
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
112
				assert.Equal(t, "application/json; charset=utf-8", contentType)
113
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
114
				require.NoError(t, err)
115
116
117

				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
Michael Yang's avatar
lint  
Michael Yang committed
118
				require.NoError(t, err)
119

Michael Yang's avatar
lint  
Michael Yang committed
120
121
				assert.Len(t, modelList.Models, 1)
				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
122
123
124
125
126
127
128
			},
		},
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
Michael Yang's avatar
Michael Yang committed
129
				fname := createTestFile(t, "ollama-model")
130
131
132
133

				stream := false
				createReq := api.CreateRequest{
					Name:      "t-bone",
Michael Yang's avatar
Michael Yang committed
134
					Modelfile: fmt.Sprintf("FROM %s", fname),
135
136
137
					Stream:    &stream,
				}
				jsonData, err := json.Marshal(createReq)
Michael Yang's avatar
lint  
Michael Yang committed
138
				require.NoError(t, err)
139
140
141
142
143
144
145

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, "application/json", contentType)
				_, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
146
147
				require.NoError(t, err)
				assert.Equal(t, 200, resp.StatusCode)
148
149

				model, err := GetModel("t-bone")
Michael Yang's avatar
lint  
Michael Yang committed
150
				require.NoError(t, err)
151
152
153
154
155
156
157
158
159
160
161
162
163
164
				assert.Equal(t, "t-bone:latest", model.ShortName)
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
Michael Yang's avatar
lint  
Michael Yang committed
165
				require.NoError(t, err)
166
167
168
169
170

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				model, err := GetModel("beefsteak")
Michael Yang's avatar
lint  
Michael Yang committed
171
				require.NoError(t, err)
172
				assert.Equal(t, "beefsteak:latest", model.ShortName)
173
174
			},
		},
Patrick Devine's avatar
Patrick Devine committed
175
176
177
178
179
180
181
182
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
Michael Yang's avatar
lint  
Michael Yang committed
183
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
184
185
186
187
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
188
				assert.Equal(t, "application/json; charset=utf-8", contentType)
Patrick Devine's avatar
Patrick Devine committed
189
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
190
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
191
192
193

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
Michael Yang's avatar
lint  
Michael Yang committed
194
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210

				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				assert.Equal(t, expectedParams, params)
			},
		},
211
212
	}

213
214
	t.Setenv("OLLAMA_MODELS", t.TempDir())

215
	s := &Server{}
216
217
218
219
220
221
	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
Michael Yang's avatar
Michael Yang committed
222
223
224
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
Michael Yang's avatar
lint  
Michael Yang committed
225
			require.NoError(t, err)
Michael Yang's avatar
Michael Yang committed
226
227
228
229
230
231

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
Michael Yang's avatar
lint  
Michael Yang committed
232
			require.NoError(t, err)
Michael Yang's avatar
Michael Yang committed
233
234
235
236
237
238
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
239
240
	}
}
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319

// TestCase checks that model names are treated case-insensitively: after a
// model has been created under one name, creating, pulling, or copying to the
// upper-cased form of that name must be rejected with HTTP 400 and the
// "a model with that name already exists" error body.
func TestCase(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	cases := []string{
		"mistral",
		"llama3:latest",
		"library/phi3:q4_0",
		"registry.ollama.ai/library/gemma:q5_K_M",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/alice/bob:latest",
	}

	var s Server
	for _, tt := range cases {
		t.Run(tt, func(t *testing.T) {
			// Create the model under its original spelling first; this
			// request is expected to succeed.
			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
				Name:      tt,
				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
				Stream:    &stream,
			})

			if w.Code != http.StatusOK {
				t.Fatalf("expected status %d got %d", http.StatusOK, w.Code)
			}

			// Every conflicting request below must return exactly this body.
			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
			if err != nil {
				t.Fatal(err)
			}

			t.Run("create", func(t *testing.T) {
				w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
					Name:      strings.ToUpper(tt),
					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
					Stream:    &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status %d got %d", http.StatusBadRequest, w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("pull", func(t *testing.T) {
				w := createRequest(t, s.PullModelHandler, api.PullRequest{
					Name:   strings.ToUpper(tt),
					Stream: &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status %d got %d", http.StatusBadRequest, w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("copy", func(t *testing.T) {
				w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
					Source:      tt,
					Destination: strings.ToUpper(tt),
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status %d got %d", http.StatusBadRequest, w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})
		})
	}
}