routes_test.go 6.67 KB
Newer Older
1
2
3
package server

import (
4
	"bytes"
5
	"context"
6
7
	"encoding/json"
	"fmt"
8
9
10
	"io"
	"net/http"
	"net/http/httptest"
11
	"os"
Patrick Devine's avatar
Patrick Devine committed
12
	"sort"
13
	"strings"
14
15
16
	"testing"

	"github.com/stretchr/testify/assert"
17
18

	"github.com/jmorganca/ollama/api"
19
	"github.com/jmorganca/ollama/llm"
20
	"github.com/jmorganca/ollama/parser"
21
	"github.com/jmorganca/ollama/version"
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
)

// setupServer builds the Server under test. It marks itself as a test helper
// and otherwise delegates to NewServer, passing its results through unchanged.
func setupServer(t *testing.T) (*Server, error) {
	t.Helper()

	srv, err := NewServer()
	return srv, err
}

func Test_Routes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
	}
38

39
40
41
42
43
44
45
46
	createTestFile := func(t *testing.T, name string) string {
		f, err := os.CreateTemp(t.TempDir(), name)
		assert.Nil(t, err)
		defer f.Close()

		_, err = f.Write([]byte("GGUF"))
		assert.Nil(t, err)
		_, err = f.Write([]byte{0x2, 0})
47
48
		assert.Nil(t, err)

49
50
51
52
53
54
		return f.Name()
	}

	createTestModel := func(t *testing.T, name string) {
		fname := createTestFile(t, "ollama-model")

Patrick Devine's avatar
Patrick Devine committed
55
		modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
56
57
		commands, err := parser.Parse(modelfile)
		assert.Nil(t, err)
58
59
60
		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}
61
62
63
		err = CreateModel(context.TODO(), name, "", commands, fn)
		assert.Nil(t, err)
	}
64
65
66
67
68
69
70
71
72
73
74
75
76

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, contentType, "application/json; charset=utf-8")
				body, err := io.ReadAll(resp.Body)
				assert.Nil(t, err)
77
				assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
78
79
			},
		},
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, contentType, "application/json; charset=utf-8")
				body, err := io.ReadAll(resp.Body)
				assert.Nil(t, err)

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
				assert.Nil(t, err)

				assert.Equal(t, 0, len(modelList.Models))
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
103
				createTestModel(t, "test-model")
104
105
106
107
108
109
110
111
112
113
114
115
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, contentType, "application/json; charset=utf-8")
				body, err := io.ReadAll(resp.Body)
				assert.Nil(t, err)

				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
				assert.Nil(t, err)

				assert.Equal(t, 1, len(modelList.Models))
116
117
118
119
120
121
122
123
				assert.Equal(t, modelList.Models[0].Name, "test-model:latest")
			},
		},
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
124
				f, err := os.CreateTemp(t.TempDir(), "ollama-model")
125
				assert.Nil(t, err)
126
				defer f.Close()
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169

				stream := false
				createReq := api.CreateRequest{
					Name:      "t-bone",
					Modelfile: fmt.Sprintf("FROM %s", f.Name()),
					Stream:    &stream,
				}
				jsonData, err := json.Marshal(createReq)
				assert.Nil(t, err)

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, "application/json", contentType)
				_, err := io.ReadAll(resp.Body)
				assert.Nil(t, err)
				assert.Equal(t, resp.StatusCode, 200)

				model, err := GetModel("t-bone")
				assert.Nil(t, err)
				assert.Equal(t, "t-bone:latest", model.ShortName)
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
				assert.Nil(t, err)

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				model, err := GetModel("beefsteak")
				assert.Nil(t, err)
				assert.Equal(t, "beefsteak:latest", model.ShortName)
170
171
			},
		},
Patrick Devine's avatar
Patrick Devine committed
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
				assert.Nil(t, err)
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, contentType, "application/json; charset=utf-8")
				body, err := io.ReadAll(resp.Body)
				assert.Nil(t, err)

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
				assert.Nil(t, err)

				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				assert.Equal(t, expectedParams, params)
			},
		},
208
209
210
211
212
213
214
215
216
217
	}

	s, err := setupServer(t)
	assert.Nil(t, err)

	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

218
219
220
221
222
	workDir, err := os.MkdirTemp("", "ollama-test")
	assert.Nil(t, err)
	defer os.RemoveAll(workDir)
	os.Setenv("OLLAMA_MODELS", workDir)

223
	for _, tc := range testCases {
224
		t.Logf("Running Test: [%s]", tc.Name)
225
226
227
228
229
230
231
232
233
234
		u := httpSrv.URL + tc.Path
		req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
		assert.Nil(t, err)

		if tc.Setup != nil {
			tc.Setup(t, req)
		}

		resp, err := httpSrv.Client().Do(req)
		assert.Nil(t, err)
Michael Yang's avatar
Michael Yang committed
235
		defer resp.Body.Close()
236
237
238
239

		if tc.Expected != nil {
			tc.Expected(t, resp)
		}
240

241
242
	}
}
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266

// MockLLM is a test stub. Encode returns the preconfigured encoding; every
// other method does no work and reports success.
type MockLLM struct {
	// encoding is returned as-is by Encode.
	encoding []int
}

// Predict does nothing: it never invokes fn and always returns nil.
func (m *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
	return nil
}

// Encode ignores prompt and returns m.encoding with a nil error.
func (m *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
	return m.encoding, nil
}

// Decode ignores tokens and returns an empty string with a nil error.
func (m *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
	return "", nil
}

// Embedding ignores input and returns an empty, non-nil slice with a nil error.
func (m *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
	return []float64{}, nil
}

// Close is a no-op.
func (m *MockLLM) Close() {
	// do nothing
}