// server_test.go

package registry

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"io/fs"
	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"regexp"
	"strings"
	"sync"
	"testing"

	"github.com/ollama/ollama/server/internal/cache/blob"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/testutil"
	"golang.org/x/tools/txtar"

	_ "embed"
)

// panicTransport is an http.RoundTripper that must never actually be used:
// sending any request through it panics. newTestServer installs it when no
// upstream registry handler is supplied, so accidental network use fails loudly.
type panicTransport struct{}

// RoundTrip satisfies http.RoundTripper by panicking unconditionally.
func (*panicTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

// panicOnRoundTrip is an HTTP client whose every request panics.
var panicOnRoundTrip = &http.Client{
	Transport: new(panicTransport),
}

// bytesResetter is the deliberately narrow view of captured log output used
// by these tests: a byte slice can be taken from it, and it can be reset.
// Nothing else is exposed, so tests cannot accidentally reach for the
// Read/Write methods of a concrete bytes.Buffer when checking logs.
type bytesResetter interface {
	Reset()
	Bytes() []byte
}

42
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
43
44
45
46
47
48
49
50
51
52
	t.Helper()
	dir := t.TempDir()
	err := os.CopyFS(dir, os.DirFS("testdata/models"))
	if err != nil {
		t.Fatal(err)
	}
	c, err := blob.Open(dir)
	if err != nil {
		t.Fatal(err)
	}
53
54
55
56
57
58
59
60
61
62
63
64
65

	client := panicOnRoundTrip
	if upstreamRegistry != nil {
		s := httptest.NewTLSServer(upstreamRegistry)
		t.Cleanup(s.Close)
		tr := s.Client().Transport.(*http.Transport).Clone()
		tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
		}
		client = &http.Client{Transport: tr}
	}

66
	rc := &ollama.Registry{
67
		Cache:      c,
68
69
		HTTPClient: client,
		Mask:       "example.com/library/_:latest",
70
	}
71

72
73
74
75
76
77
78
79
80
	l := &Local{
		Client: rc,
		Logger: testutil.Slogger(t),
	}
	return l
}

// send builds a request with the given method, path and body, serves it with
// s, and returns the recorded response. The request context carries an
// ollama.Trace whose Update callback logs each reported layer digest, count
// and error to the test log.
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
	t.Helper()
	ctx := ollama.WithTrace(t.Context(), &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			t.Logf("update: %s %d %v", l.Digest, n, err)
		},
	})
	req := httptest.NewRequestWithContext(ctx, method, path, strings.NewReader(body))
	return s.sendRequest(t, req)
}

// sendRequest dispatches req directly to s.ServeHTTP and hands back the
// recorder that captured the response.
func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
	t.Helper()
	rec := httptest.NewRecorder()
	s.ServeHTTP(rec, req)
	return rec
}

// invalidReader is an io.Reader that never yields data: every Read reports
// zero bytes and os.ErrInvalid. Tests use it to simulate a request body that
// cannot be read.
type invalidReader struct{}

func (*invalidReader) Read(_ []byte) (int, error) {
	return 0, os.ErrInvalid
}

// captureLogs lets a test inspect what the server logs. It returns a shallow
// copy of s whose Logger is replaced by the logger from testutil.SlogBuffer,
// together with the matching buffer handle. The original s keeps its logger.
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
	t.Helper()
	logger, buf := testutil.SlogBuffer()
	clone := *s // shallow: fields such as Client are shared with s
	clone.Logger = logger
	return &clone, buf
}

// TestServerDelete covers /api/delete.
//
// It checks a successful delete of the local "smol" model, then the error
// responses for a malformed body, the wrong method, an empty body, an
// invalid name, an unknown path, and a request body that fails to read.
// The last case must return a 500 and log the underlying error.
func TestServerDelete(t *testing.T) {
	check := testutil.Checker(t)

	s := newTestServer(t, nil)

	_, err := s.Client.ResolveLocal("smol")
	check(err)

	got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
	if got.Code != 200 {
		t.Fatalf("Code = %d; want 200", got.Code)
	}

	_, err = s.Client.ResolveLocal("smol")
	if err == nil {
		t.Fatal("expected smol to have been deleted")
	}

	got = s.send(t, "DELETE", "/api/delete", `!`)
	checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")

	got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
	checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")

	got = s.send(t, "DELETE", "/api/delete", ``)
	checkErrorResponse(t, got, 400, "bad_request", "empty request body")

	got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
	checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")

	got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
	checkErrorResponse(t, got, 404, "not_found", "not found")

	// A body that cannot be read must surface as a 500 and be logged.
	s, logs := captureLogs(t, s)
	req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
	got = s.sendRequest(t, req)
	checkErrorResponse(t, got, 500, "internal_error", "internal server error")
	ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
	check(err)
	if !ok {
		// Print via Bytes(): bytesResetter does not promise a String method.
		t.Logf("logs:\n%s", logs.Bytes())
		t.Fatalf("expected log to contain ERROR with invalid argument")
	}
}

159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
//go:embed testdata/registry.txt
var registryTXT []byte

var registryFS = sync.OnceValue(func() fs.FS {
	// Txtar gets hung up on \r\n line endings, so we need to convert them
	// to \n when parsing the txtar on Windows.
	data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
	a := txtar.Parse(data)
	fsys, err := txtar.FS(a)
	if err != nil {
		panic(err)
	}
	return fsys
})

// TestServerPull covers /api/pull against a fake upstream registry.
//
// The fake registry serves files from registryFS and canned error responses
// for the "BOOM" and "unknown" manifests. The test checks both streaming and
// non-streaming pulls, plus the error cases for a bad method, a malformed
// body, an empty body and an invalid name.
func TestServerPull(t *testing.T) {
	modelsHandler := http.FileServerFS(registryFS())
	s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/v2/library/BOOM/manifests/latest":
			w.WriteHeader(999)
			io.WriteString(w, `{"error": "boom"}`)
		case "/v2/library/unknown/manifests/latest":
			w.WriteHeader(404)
			io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
		default:
			t.Logf("serving blob: %s", r.URL.Path)
			modelsHandler.ServeHTTP(w, r)
		}
	})

	// checkResponse reports an error unless got has status 200, then checks
	// the body against wantlines one line at a time:
	//   - a plain line must appear somewhere in the body;
	//   - a line prefixed with "!" must not appear.
	// Surrounding whitespace on each line is ignored.
	checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
		t.Helper()
		if got.Code != 200 {
			t.Errorf("Code = %d; want 200", got.Code)
		}
		gotlines := got.Body.String()
		if strings.TrimSpace(gotlines) == "" {
			gotlines = "<empty>"
		}
		t.Logf("got:\n%s", gotlines)
		for line := range strings.Lines(wantlines) {
			want, unwanted := strings.CutPrefix(strings.TrimSpace(line), "!")
			want = strings.TrimSpace(want)
			if !unwanted && !strings.Contains(gotlines, want) {
				t.Errorf("\t! missing %q in body", want)
			}
			if unwanted && strings.Contains(gotlines, want) {
				t.Errorf("\t! unexpected %q in body", want)
			}
		}
	}

	got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
	checkResponse(got, `
		{"status":"pulling manifest"}
		{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
		{"status":"verifying sha256 digest"}
		{"status":"writing manifest"}
		{"status":"success"}
	`)

	got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
	checkResponse(got, `
		{"code":"not_found","error":"model \"unknown\" not found"}
	`)

	got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
	checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")

	got = s.send(t, "POST", "/api/pull", `!`)
	checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")

	got = s.send(t, "POST", "/api/pull", ``)
	checkErrorResponse(t, got, 400, "bad_request", "empty request body")

	got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
	checkResponse(got, `
		{"code":"bad_request","error":"invalid or missing name: \"\""}
	`)

	// Non-streaming pulls
	got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
	checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
	got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
	checkResponse(got, `
		{"status":"success"}
		!digest
		!total
		!completed
	`)
	got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
	checkErrorResponse(t, got, 404, "not_found", "model not found")
}

255
func TestServerUnknownPath(t *testing.T) {
256
	s := newTestServer(t, nil)
257
258
	got := s.send(t, "DELETE", "/api/unknown", `{}`)
	checkErrorResponse(t, got, 404, "not_found", "not found")
259
260
261
262
263
264
265
266
267
268
269
270

	var fellback bool
	s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fellback = true
	})
	got = s.send(t, "DELETE", "/api/unknown", `{}`)
	if !fellback {
		t.Fatal("expected Fallback to be called")
	}
	if got.Code != 200 {
		t.Fatalf("Code = %d; want 200", got.Code)
	}
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
}

// checkErrorResponse asserts that got is an error response.
//
// The HTTP status must equal status, and the JSON body must decode into an
// *ollama.Error whose Code equals code and whose Message contains msg.
// Mismatches are reported with t.Errorf. A body that does not decode to a
// non-nil error stops the test immediately. The raw body is logged once,
// before the first reported failure.
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
	t.Helper()

	var printedBody bool
	errorf := func(format string, args ...any) {
		t.Helper()
		if !printedBody {
			t.Logf("BODY:\n%s", got.Body.String())
			printedBody = true
		}
		t.Errorf(format, args...)
	}

	if got.Code != status {
		errorf("Code = %d; want %d", got.Code, status)
	}

	// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
	var e *ollama.Error
	if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
		errorf("unmarshal error: %v", err)
		t.FailNow()
	}
	// A literal JSON null decodes without error but leaves e nil. Fail
	// cleanly instead of panicking on the field accesses below.
	if e == nil {
		errorf("body decoded to a nil error; want an error object")
		t.FailNow()
	}
	if e.Code != code {
		errorf("error code = %q; want %q", e.Code, code)
	}
	if !strings.Contains(e.Message, msg) {
		errorf("Message = %q; want to contain %q", e.Message, msg)
	}
}