// routes_test.go

package server

import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"sort"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
)

func createTestFile(t *testing.T, name string) string {
	t.Helper()
29

30
	f, err := os.CreateTemp(t.TempDir(), name)
Michael Yang's avatar
Michael Yang committed
31
	require.NoError(t, err)
32
	defer f.Close()
33

34
	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
Michael Yang's avatar
Michael Yang committed
35
	require.NoError(t, err)
36

37
	err = binary.Write(f, binary.LittleEndian, uint32(3))
Michael Yang's avatar
Michael Yang committed
38
	require.NoError(t, err)
39

40
	err = binary.Write(f, binary.LittleEndian, uint64(0))
Michael Yang's avatar
Michael Yang committed
41
	require.NoError(t, err)
42

43
	err = binary.Write(f, binary.LittleEndian, uint64(0))
Michael Yang's avatar
Michael Yang committed
44
	require.NoError(t, err)
45

46
47
	return f.Name()
}
48

49
50
51
52
53
54
55
func Test_Routes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
56
57
58
	}

	createTestModel := func(t *testing.T, name string) {
59
60
		t.Helper()

61
62
		fname := createTestFile(t, "ollama-model")

Michael Yang's avatar
Michael Yang committed
63
		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
64
		modelfile, err := parser.ParseFile(r)
Michael Yang's avatar
lint  
Michael Yang committed
65
		require.NoError(t, err)
66
67
68
		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}
69
		err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
Michael Yang's avatar
lint  
Michael Yang committed
70
		require.NoError(t, err)
71
	}
72
73
74
75
76
77
78
79
80
81

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
82
				assert.Equal(t, "application/json; charset=utf-8", contentType)
83
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
84
				require.NoError(t, err)
85
				assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
86
87
			},
		},
88
89
90
91
92
93
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
94
				assert.Equal(t, "application/json; charset=utf-8", contentType)
95
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
96
				require.NoError(t, err)
97
98
99
100

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
Michael Yang's avatar
lint  
Michael Yang committed
101
				require.NoError(t, err)
102

103
				assert.NotNil(t, modelList.Models)
Michael Yang's avatar
lint  
Michael Yang committed
104
				assert.Empty(t, len(modelList.Models))
105
106
107
108
109
110
111
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
112
				createTestModel(t, "test-model")
113
114
115
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
116
				assert.Equal(t, "application/json; charset=utf-8", contentType)
117
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
118
				require.NoError(t, err)
119

120
121
				assert.NotContains(t, string(body), "expires_at")

122
123
				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
Michael Yang's avatar
lint  
Michael Yang committed
124
				require.NoError(t, err)
125

Michael Yang's avatar
lint  
Michael Yang committed
126
127
				assert.Len(t, modelList.Models, 1)
				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
128
129
130
131
132
133
134
			},
		},
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
Michael Yang's avatar
Michael Yang committed
135
				fname := createTestFile(t, "ollama-model")
136
137
138
139

				stream := false
				createReq := api.CreateRequest{
					Name:      "t-bone",
Michael Yang's avatar
Michael Yang committed
140
					Modelfile: fmt.Sprintf("FROM %s", fname),
141
142
143
					Stream:    &stream,
				}
				jsonData, err := json.Marshal(createReq)
Michael Yang's avatar
lint  
Michael Yang committed
144
				require.NoError(t, err)
145
146
147
148
149
150
151

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, "application/json", contentType)
				_, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
152
153
				require.NoError(t, err)
				assert.Equal(t, 200, resp.StatusCode)
154
155

				model, err := GetModel("t-bone")
Michael Yang's avatar
lint  
Michael Yang committed
156
				require.NoError(t, err)
157
158
159
160
161
162
163
164
165
166
167
168
169
170
				assert.Equal(t, "t-bone:latest", model.ShortName)
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
Michael Yang's avatar
lint  
Michael Yang committed
171
				require.NoError(t, err)
172
173
174
175
176

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				model, err := GetModel("beefsteak")
Michael Yang's avatar
lint  
Michael Yang committed
177
				require.NoError(t, err)
178
				assert.Equal(t, "beefsteak:latest", model.ShortName)
179
180
			},
		},
Patrick Devine's avatar
Patrick Devine committed
181
182
183
184
185
186
187
188
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
Michael Yang's avatar
lint  
Michael Yang committed
189
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
190
191
192
193
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
Michael Yang's avatar
lint  
Michael Yang committed
194
				assert.Equal(t, "application/json; charset=utf-8", contentType)
Patrick Devine's avatar
Patrick Devine committed
195
				body, err := io.ReadAll(resp.Body)
Michael Yang's avatar
lint  
Michael Yang committed
196
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
197
198
199

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
Michael Yang's avatar
lint  
Michael Yang committed
200
				require.NoError(t, err)
Patrick Devine's avatar
Patrick Devine committed
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216

				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				assert.Equal(t, expectedParams, params)
			},
		},
217
218
	}

219
	t.Setenv("OLLAMA_MODELS", t.TempDir())
220
	envconfig.LoadConfig()
221

222
	s := &Server{}
223
224
225
226
227
228
	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
Michael Yang's avatar
Michael Yang committed
229
230
231
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
Michael Yang's avatar
lint  
Michael Yang committed
232
			require.NoError(t, err)
Michael Yang's avatar
Michael Yang committed
233
234
235
236
237
238

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
Michael Yang's avatar
lint  
Michael Yang committed
239
			require.NoError(t, err)
Michael Yang's avatar
Michael Yang committed
240
241
242
243
244
245
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
246
247
	}
}
248
249
250

func TestCase(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())
251
	envconfig.LoadConfig()
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266

	cases := []string{
		"mistral",
		"llama3:latest",
		"library/phi3:q4_0",
		"registry.ollama.ai/library/gemma:q5_K_M",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/alice/bob:latest",
	}

	var s Server
	for _, tt := range cases {
		t.Run(tt, func(t *testing.T) {
			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
				Name:      tt,
267
				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
				Stream:    &stream,
			})

			if w.Code != http.StatusOK {
				t.Fatalf("expected status 200 got %d", w.Code)
			}

			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
			if err != nil {
				t.Fatal(err)
			}

			t.Run("create", func(t *testing.T) {
				w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
					Name:      strings.ToUpper(tt),
283
					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
					Stream:    &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("pull", func(t *testing.T) {
				w := createRequest(t, s.PullModelHandler, api.PullRequest{
					Name:   strings.ToUpper(tt),
					Stream: &stream,
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("copy", func(t *testing.T) {
				w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
					Source:      tt,
					Destination: strings.ToUpper(tt),
				})

				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 500 got %d", w.Code)
				}

				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})
		})
	}
}