routes_test.go 3.03 KB
Newer Older
1
2
3
4
package server

import (
	"context"
5
6
	"encoding/json"
	"fmt"
7
8
9
	"io"
	"net/http"
	"net/http/httptest"
10
11
	"os"
	"strings"
12
13
14
	"testing"

	"github.com/stretchr/testify/assert"
15
16
17

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/parser"
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
)

// setupServer constructs the Server under test by delegating to NewServer and
// hands back whatever NewServer returns, including its error. It is marked as
// a test helper so failures are attributed to the calling test.
func setupServer(t *testing.T) (*Server, error) {
	t.Helper()

	srv, err := NewServer()
	return srv, err
}

func Test_Routes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
	}

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, contentType, "application/json; charset=utf-8")
				body, err := io.ReadAll(resp.Body)
				assert.Nil(t, err)
				assert.Equal(t, `{"version":"0.0.0"}`, string(body))
			},
		},
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, contentType, "application/json; charset=utf-8")
				body, err := io.ReadAll(resp.Body)
				assert.Nil(t, err)

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
				assert.Nil(t, err)

				assert.Equal(t, 0, len(modelList.Models))
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
				f, err := os.CreateTemp("", "ollama-modelfile")
				assert.Nil(t, err)
				defer os.RemoveAll(f.Name())

				modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name()))
				commands, err := parser.Parse(modelfile)
				assert.Nil(t, err)
				fn := func(resp api.ProgressResponse) {}
				err = CreateModel(context.TODO(), "test", "", commands, fn)
				assert.Nil(t, err)
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				assert.Equal(t, contentType, "application/json; charset=utf-8")
				body, err := io.ReadAll(resp.Body)
				assert.Nil(t, err)

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
				assert.Nil(t, err)

				assert.Equal(t, 1, len(modelList.Models))
				assert.Equal(t, modelList.Models[0].Name, "test:latest")
			},
		},
99
100
101
102
103
104
105
106
107
108
	}

	s, err := setupServer(t)
	assert.Nil(t, err)

	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

109
110
111
112
113
	workDir, err := os.MkdirTemp("", "ollama-test")
	assert.Nil(t, err)
	defer os.RemoveAll(workDir)
	os.Setenv("OLLAMA_MODELS", workDir)

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
	for _, tc := range testCases {
		u := httpSrv.URL + tc.Path
		req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
		assert.Nil(t, err)

		if tc.Setup != nil {
			tc.Setup(t, req)
		}

		resp, err := httpSrv.Client().Do(req)
		assert.Nil(t, err)

		if tc.Expected != nil {
			tc.Expected(t, resp)
		}
	}

}