package registry

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"io/fs"
	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"regexp"
	"strings"
	"sync"
	"testing"

	"github.com/ollama/ollama/server/internal/cache/blob"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/testutil"

	"golang.org/x/tools/txtar"

	_ "embed"
)

// panicTransport is an http.RoundTripper that refuses every request by
// panicking, so any unexpected outbound HTTP traffic in a test fails loudly.
type panicTransport struct{}

// Compile-time proof that *panicTransport satisfies http.RoundTripper.
var _ http.RoundTripper = (*panicTransport)(nil)

// RoundTrip always panics; it never returns.
func (*panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

// panicOnRoundTrip is an *http.Client whose every request panics. It is the
// default client for test servers that have no upstream registry.
var panicOnRoundTrip = &http.Client{
	Transport: new(panicTransport),
}

// bytesResetter is the deliberately narrow view of a log buffer handed to
// tests: they may read the accumulated bytes and clear them, and nothing else.
// Keeping io.Reader/io.Writer methods (as on bytes.Buffer) out of the method
// set prevents a test from accidentally draining or appending to the logs it
// is meant to inspect.
type bytesResetter interface {
	// Reset discards everything captured so far.
	Reset()
	// Bytes returns everything captured since the last Reset.
	Bytes() []byte
}

// newTestServer returns a *Local whose registry client is backed by a blob
// cache seeded with a copy of testdata/models in a fresh temporary directory.
//
// If upstreamRegistry is nil, the client uses panicOnRoundTrip, so any attempt
// to reach the network panics. Otherwise upstreamRegistry is served by an
// httptest TLS server. Every dial made by the client is redirected to that
// server regardless of the requested host, so tests can use names that
// resolve to hosts such as example.com.
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
	t.Helper()

	dir := t.TempDir()
	if err := os.CopyFS(dir, os.DirFS("testdata/models")); err != nil {
		t.Fatal(err)
	}
	c, err := blob.Open(dir)
	if err != nil {
		t.Fatal(err)
	}

	client := panicOnRoundTrip
	if upstreamRegistry != nil {
		s := httptest.NewTLSServer(upstreamRegistry)
		t.Cleanup(s.Close)

		// Start from the test server's own transport so the client trusts
		// its TLS test certificate, then ignore the dial target so requests
		// for any host/port land on the test server.
		tr := s.Client().Transport.(*http.Transport).Clone()
		tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
		}
		client = &http.Client{Transport: tr}
	}

	rc := &ollama.Registry{
		Cache:      c,
		HTTPClient: client,
		// Presumably fills in the host, namespace, and tag of short names
		// such as "smol". TestServerPull expects "BOOM" to be fetched from
		// example.com/v2/library/BOOM/manifests/latest.
		Mask: "example.com/library/_:latest",
	}

	return &Local{
		Client: rc,
		Logger: testutil.Slogger(t),
	}
}

// send builds a request with the given method, path, and string body, bound
// to the test's context, and hands it to sendRequest. It returns the recorded
// response.
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
	t.Helper()
	payload := strings.NewReader(body)
	return s.sendRequest(t, httptest.NewRequestWithContext(t.Context(), method, path, payload))
}

// sendRequest dispatches req directly to s.ServeHTTP, with no network
// involved, and returns the recorder holding the response.
func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
	t.Helper()
	rec := httptest.NewRecorder()
	s.ServeHTTP(rec, req)
	return rec
}

// invalidReader is an io.Reader that never yields data: every Read fails with
// os.ErrInvalid. Tests use it as a request body to force a body-read failure.
type invalidReader struct{}

// Read reports zero bytes and os.ErrInvalid, ignoring the buffer.
func (*invalidReader) Read(_ []byte) (int, error) {
	return 0, os.ErrInvalid
}

// captureLogs returns a shallow copy of s whose Logger comes from
// testutil.SlogBuffer, together with that logger's captured output. Use the
// output to inspect what the copy logs. The original s keeps its own logger
// and is not modified.
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
	t.Helper()
	logger, captured := testutil.SlogBuffer()
	clone := *s // shallow: all other fields are shared with s
	clone.Logger = logger
	return &clone, captured
}

// TestServerDelete exercises DELETE /api/delete. It covers a successful delete
// of a seeded model, malformed and invalid requests, wrong methods, unknown
// paths, and a body-read failure that must be logged and reported as a 500.
func TestServerDelete(t *testing.T) {
	check := testutil.Checker(t)

	s := newTestServer(t, nil)

	// The seeded testdata must contain "smol" before we delete it.
	_, err := s.Client.ResolveLocal("smol")
	check(err)

	got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
	if got.Code != 200 {
		t.Fatalf("Code = %d; want 200", got.Code)
	}

	_, err = s.Client.ResolveLocal("smol")
	if err == nil {
		t.Fatal("expected smol to have been deleted")
	}

	got = s.send(t, "DELETE", "/api/delete", `!`)
	checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")

	got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
	checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")

	got = s.send(t, "DELETE", "/api/delete", ``)
	checkErrorResponse(t, got, 400, "bad_request", "empty request body")

	got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
	checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")

	got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
	checkErrorResponse(t, got, 404, "not_found", "not found")

	// A body that fails to read must produce a generic 500 for the client
	// while the underlying error is logged at ERROR level.
	s, logs := captureLogs(t, s)
	req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
	got = s.sendRequest(t, req)
	checkErrorResponse(t, got, 500, "internal_error", "internal server error")
	ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
	check(err)
	if !ok {
		// bytesResetter only guarantees Bytes, so print that rather than
		// relying on the concrete type having a String method.
		t.Logf("logs:\n%s", logs.Bytes())
		t.Fatalf("expected log to contain ERROR with invalid argument")
	}
}

// registryTXT is a txtar archive holding the files that TestServerPull's fake
// upstream registry serves.
//
//go:embed testdata/registry.txt
var registryTXT []byte

// registryFS parses registryTXT into an fs.FS exactly once, on first use.
// It panics if the archive cannot be turned into a file system.
var registryFS = sync.OnceValue(func() fs.FS {
	// Txtar gets hung up on \r\n line endings, so we need to convert them
	// to \n when parsing the txtar on Windows.
	data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
	a := txtar.Parse(data)
	fsys, err := txtar.FS(a)
	if err != nil {
		panic(err)
	}
	return fsys
})

// TestServerPull exercises POST /api/pull against a fake upstream registry
// backed by registryFS. It covers streaming and non-streaming pulls, upstream
// errors, unknown models, wrong methods, and malformed requests.
func TestServerPull(t *testing.T) {
	modelsHandler := http.FileServerFS(registryFS())
	s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/v2/library/BOOM/manifests/latest":
			// An unusual status whose code and body must be surfaced to
			// the caller (see the "BOOM" expectation below).
			w.WriteHeader(999)
			io.WriteString(w, `{"error": "boom"}`)
		case "/v2/library/unknown/manifests/latest":
			w.WriteHeader(404)
			io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
		default:
			t.Logf("serving blob: %s", r.URL.Path)
			modelsHandler.ServeHTTP(w, r)
		}
	})

	// checkResponse asserts that got has status 200 and that its body
	// contains every non-blank line of wantlines as a substring, after
	// trimming whitespace. A line prefixed with "!" instead asserts that
	// the rest of the line does NOT appear in the body.
	checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
		t.Helper()

		if got.Code != 200 {
			t.Errorf("Code = %d; want 200", got.Code)
		}
		gotlines := got.Body.String()
		t.Logf("got:\n%s", gotlines)
		for line := range strings.Lines(wantlines) {
			line = strings.TrimSpace(line)
			if line == "" {
				// Leading/trailing lines of the raw-string layout carry
				// no expectation.
				continue
			}
			want, unwanted := strings.CutPrefix(line, "!")
			want = strings.TrimSpace(want)
			if !unwanted && !strings.Contains(gotlines, want) {
				t.Errorf("! missing %q in body", want)
			}
			if unwanted && strings.Contains(gotlines, want) {
				t.Errorf("! unexpected %q in body", want)
			}
		}
	}

	got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
	checkResponse(got, `
		{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
	`)

	got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
	checkResponse(got, `
		{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
		{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
		{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
		{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
	`)

	got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
	checkResponse(got, `
		{"status":"error: model \"unknown\" not found"}
	`)

	got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
	checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")

	got = s.send(t, "POST", "/api/pull", `!`)
	checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")

	got = s.send(t, "POST", "/api/pull", ``)
	checkErrorResponse(t, got, 400, "bad_request", "empty request body")

	// In streaming mode, a bad name is reported in-band as a status line.
	got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
	checkResponse(got, `
		{"status":"error: invalid or missing name: \"\""}
	`)

	// Non-streaming pulls report errors as HTTP errors and success as a
	// single status line without progress fields.
	got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
	checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
	got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
	checkResponse(got, `
		{"status":"success"}
		!digest
		!total
		!completed
	`)

	got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
	checkErrorResponse(t, got, 404, "not_found", "model not found")
}

// TestServerUnknownPath checks how requests to unrecognized paths are
// handled. Without a Fallback they get a structured 404 error. With a
// Fallback set, the request is handed to it.
func TestServerUnknownPath(t *testing.T) {
	s := newTestServer(t, nil)

	got := s.send(t, "DELETE", "/api/unknown", `{}`)
	checkErrorResponse(t, got, 404, "not_found", "not found")

	var fellback bool
	s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fellback = true
	})
	got = s.send(t, "DELETE", "/api/unknown", `{}`)
	if !fellback {
		t.Fatal("expected Fallback to be called")
	}
	// The Fallback writes nothing, so the recorder's default 200 stands.
	if got.Code != 200 {
		t.Fatalf("Code = %d; want 200", got.Code)
	}
}

// checkErrorResponse asserts three things about got:
//   - it has HTTP status `status`;
//   - its body decodes as an *ollama.Error;
//   - that error's Code equals code and its Message contains msg.
//
// The response body is logged once, just before the first reported failure.
// The test stops immediately if the body cannot be decoded into an error
// object.
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
	t.Helper()

	var printedBody bool
	errorf := func(format string, args ...any) {
		t.Helper()
		if !printedBody {
			t.Logf("BODY:\n%s", got.Body.String())
			printedBody = true
		}
		t.Errorf(format, args...)
	}

	if got.Code != status {
		errorf("Code = %d; want %d", got.Code, status)
	}

	// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
	var e *ollama.Error
	if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
		errorf("unmarshal error: %v", err)
		t.FailNow()
	}
	// A literal JSON null decodes without error but leaves e nil. Report it
	// rather than panicking on the field accesses below.
	if e == nil {
		errorf("body decoded to null; want an error object")
		t.FailNow()
	}
	if e.Code != code {
		errorf("Code = %q; want %q", e.Code, code)
	}
	if !strings.Contains(e.Message, msg) {
		errorf("Message = %q; want to contain %q", e.Message, msg)
	}
}