utils_test.go 8.9 KB
Newer Older
mashun1's avatar
v1  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
//go:build integration

package integration

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/app/lifecycle"
	"github.com/stretchr/testify/require"
)

// Init delegates to lifecycle.InitLogging. Judging by that name it sets up the
// logging used by the integration tests — NOTE(review): confirm against the
// lifecycle package, which is not visible from this file.
func Init() {
	lifecycle.InitLogging()
}

// FindPort returns a TCP port number, formatted as a decimal string, for the
// test server to listen on.
//
// It first asks the operating system for a free port by briefly listening on
// "localhost:0" and reading back the port that was assigned. The listener is
// closed before returning, so the port is only known to have been free at the
// time of the call. If resolving or listening fails, a random port from the
// dynamic/ephemeral range 49152-65535 (inclusive) is returned instead.
func FindPort() string {
	if addr, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
		if l, err := net.ListenTCP("tcp", addr); err == nil {
			port := l.Addr().(*net.TCPAddr).Port
			// Best effort: the listener only existed to discover a free port.
			_ = l.Close()
			if port != 0 {
				return strconv.Itoa(port)
			}
		}
	}

	// Fall back to a random port in the ephemeral range. The +1 makes the
	// upper bound inclusive so that 65535 can actually be selected.
	const (
		ephemeralMin = 49152
		ephemeralMax = 65535
	)
	return strconv.Itoa(rand.Intn(ephemeralMax-ephemeralMin+1) + ephemeralMin)
}

// GetTestEndpoint derives the server address for the tests from the
// OLLAMA_HOST environment variable and returns an API client for it together
// with the matching "host:port" endpoint string.
//
// OLLAMA_HOST may be empty, a bare host, "host:port", or carry a "scheme://"
// prefix and trailing slashes. The scheme defaults to "http", the host to
// 127.0.0.1 and the port to 11434. When OLLAMA_TEST_EXISTING is unset and the
// port is the default one, the port is replaced with the result of FindPort.
//
// The endpoint string and the client URL are both built with
// net.JoinHostPort, so an IPv6 host is bracketed ("[::1]:11434") in both.
func GetTestEndpoint() (*api.Client, string) {
	defaultPort := "11434"
	ollamaHost := os.Getenv("OLLAMA_HOST")

	scheme, hostport, ok := strings.Cut(ollamaHost, "://")
	if !ok {
		// No scheme present: the whole value is the host[:port] part.
		scheme, hostport = "http", ollamaHost
	}

	// trim trailing slashes
	hostport = strings.TrimRight(hostport, "/")

	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		// No usable "host:port" pair; start from the defaults and accept the
		// value as a bare IP (optionally bracketed) or a bare host name.
		host, port = "127.0.0.1", defaultPort
		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
			host = ip.String()
		} else if hostport != "" {
			host = hostport
		}
	}

	if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
		port = FindPort()
	}

	slog.Info("server connection", "host", host, "port", port)

	// Use one joined form for both return values so they cannot disagree.
	endpoint := net.JoinHostPort(host, port)
	return api.NewClient(
		&url.URL{
			Scheme: scheme,
			Host:   endpoint,
		},
		http.DefaultClient), endpoint
}

// State shared between startServer calls.
var (
	// serverMutex serializes startServer and guards serverReady.
	serverMutex sync.Mutex
	// serverReady is set once startServer has spawned the server and is
	// cleared again by startServer's watcher goroutine after the server exits.
	serverReady bool
)

// startServer spawns the ollama server for the tests unless one has already
// been started (serverReady is set).
//
// It first verifies that the built CLI exists at ../ollama (with an ".exe"
// suffix on Windows) and returns an error if it does not. It then points
// OLLAMA_HOST at ollamaHost for the duration of the test via t.Setenv, spawns
// the server through lifecycle.SpawnServer, and marks it ready after a fixed
// 500ms pause.
//
// A background goroutine waits for ctx to be cancelled, then holds serverMutex
// until the server's exit code arrives, logs a warning on a positive exit
// code, and clears serverReady.
func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
	// Make sure the server has been built
	cliPath, err := filepath.Abs("../ollama")
	if err != nil {
		return err
	}
	if runtime.GOOS == "windows" {
		cliPath += ".exe"
	}
	if _, err = os.Stat(cliPath); err != nil {
		return fmt.Errorf("CLI missing, did you forget to build first?  %w", err)
	}

	serverMutex.Lock()
	defer serverMutex.Unlock()
	if serverReady {
		// Already started; nothing more to do.
		return nil
	}

	if current := os.Getenv("OLLAMA_HOST"); current != ollamaHost {
		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
		t.Setenv("OLLAMA_HOST", ollamaHost)
	}

	slog.Info("starting server", "url", ollamaHost)
	exited, err := lifecycle.SpawnServer(ctx, "../ollama")
	if err != nil {
		return fmt.Errorf("failed to start server: %w", err)
	}

	// Watch for shutdown: once ctx is done, keep the mutex held until the
	// server has reported its exit code so no new start overlaps the old exit.
	go func() {
		<-ctx.Done()
		serverMutex.Lock()
		defer serverMutex.Unlock()
		if exitCode := <-exited; exitCode > 0 {
			slog.Warn("server failure", "exit", exitCode)
		}
		serverReady = false
	}()

	// TODO wait only long enough for the server to be responsive...
	time.Sleep(500 * time.Millisecond)

	serverReady = true
	return nil
}

func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
	slog.Info("checking status of model", "model", modelName)
	showReq := &api.ShowRequest{Name: modelName}

	showCtx, cancel := context.WithDeadlineCause(
		ctx,
xuxzh1's avatar
init  
xuxzh1 committed
143
		time.Now().Add(10*time.Second),
mashun1's avatar
v1  
mashun1 committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
		fmt.Errorf("show for existing model %s took too long", modelName),
	)
	defer cancel()
	_, err := client.Show(showCtx, showReq)
	var statusError api.StatusError
	switch {
	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
		break
	case err != nil:
		return err
	default:
		slog.Info("model already present", "model", modelName)
		return nil
	}
	slog.Info("model missing", "model", modelName)

	stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
	stallTimer := time.NewTimer(stallDuration)
	fn := func(resp api.ProgressResponse) error {
		// fmt.Print(".")
		if !stallTimer.Reset(stallDuration) {
xuxzh1's avatar
init  
xuxzh1 committed
165
			return errors.New("stall was detected, aborting status reporting")
mashun1's avatar
v1  
mashun1 committed
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
		}
		return nil
	}

	stream := true
	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}

	var pullError error

	done := make(chan int)
	go func() {
		pullError = client.Pull(ctx, pullReq, fn)
		done <- 0
	}()

	select {
	case <-stallTimer.C:
xuxzh1's avatar
init  
xuxzh1 committed
183
		return errors.New("download stalled")
mashun1's avatar
v1  
mashun1 committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
	case <-done:
		return pullError
	}
}

// serverProcMutex is locked by InitServerConnection when OLLAMA_TEST_EXISTING
// is unset and unlocked by the cleanup function it returns, so tests that
// start their own server hold it from setup until cleanup.
var serverProcMutex sync.Mutex

// InitServerConnection returns an API client, the test endpoint ("host:port"),
// and a cleanup function. It fails the test on errors.
//
// When OLLAMA_TEST_EXISTING is unset it takes serverProcMutex, creates a
// temporary server log file (recorded in lifecycle.ServerLogFile) and starts
// the server via startServer. The mutex is released again before the test is
// failed, so a setup failure in one test cannot deadlock the next one.
//
// The cleanup function must be called exactly once. If this call took
// serverProcMutex, cleanup dumps the server log to stderr when the test has
// failed, removes the log file, and releases the mutex. Otherwise it does
// nothing.
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
	client, testEndpoint := GetTestEndpoint()

	// Remember whether this call owns the server lock, so cleanup mirrors the
	// setup even if the environment changes in between.
	managed := os.Getenv("OLLAMA_TEST_EXISTING") == ""
	if managed {
		serverProcMutex.Lock()
		fp, err := os.CreateTemp("", "ollama-server-*.log")
		if err != nil {
			// t.Fatalf does not return; release the lock first.
			serverProcMutex.Unlock()
			t.Fatalf("failed to generate log file: %s", err)
		}
		lifecycle.ServerLogFile = fp.Name()
		fp.Close()
		if err := startServer(t, ctx, testEndpoint); err != nil {
			// require.NoError stops the test; release the lock first.
			serverProcMutex.Unlock()
			require.NoError(t, err)
		}
	}

	return client, testEndpoint, func() {
		if !managed {
			return
		}
		defer serverProcMutex.Unlock()

		if t.Failed() {
			fp, err := os.Open(lifecycle.ServerLogFile)
			if err != nil {
				slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
			} else {
				data, err := io.ReadAll(fp)
				// Close before removing: an open handle blocks deletion on Windows.
				fp.Close()
				if err != nil {
					slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
				} else {
					slog.Warn("SERVER LOG FOLLOWS")
					os.Stderr.Write(data)
					slog.Warn("END OF SERVER")
				}
			}
		}

		// Always try to remove the log file, even if dumping it failed.
		err := os.Remove(lifecycle.ServerLogFile)
		if err != nil && !os.IsNotExist(err) {
			slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
		}
	}
}

// GenerateTestHelper runs one generate request end to end: it sets up the
// server connection, makes sure genReq.Model is available (failing the test if
// the pull fails), and then calls DoGenerate with a 30s initial timeout and a
// 10s streaming timeout. anyResp is passed to DoGenerate as the list of
// acceptable response fragments.
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
	const (
		initialTimeout = 30 * time.Second
		streamTimeout  = 10 * time.Second
	)

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	pullErr := PullIfMissing(ctx, client, genReq.Model)
	require.NoError(t, pullErr)

	DoGenerate(ctx, t, client, genReq, anyResp, initialTimeout, streamTimeout)
}

func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
	stallTimer := time.NewTimer(initialTimeout)
	var buf bytes.Buffer
	fn := func(response api.GenerateResponse) error {
		// fmt.Print(".")
		buf.Write([]byte(response.Response))
		if !stallTimer.Reset(streamTimeout) {
xuxzh1's avatar
init  
xuxzh1 committed
246
			return errors.New("stall was detected while streaming response, aborting")
mashun1's avatar
v1  
mashun1 committed
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
		}
		return nil
	}

	stream := true
	genReq.Stream = &stream
	done := make(chan int)
	var genErr error
	go func() {
		genErr = client.Generate(ctx, &genReq, fn)
		done <- 0
	}()

	select {
	case <-stallTimer.C:
		if buf.Len() == 0 {
			t.Errorf("generate never started.  Timed out after :%s", initialTimeout.String())
		} else {
			t.Errorf("generate stalled.  Response so far:%s", buf.String())
		}
	case <-done:
		require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
		// Verify the response contains the expected data
		response := buf.String()
		atLeastOne := false
		for _, resp := range anyResp {
			if strings.Contains(strings.ToLower(response), resp) {
				atLeastOne = true
				break
			}
		}
		require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
	case <-ctx.Done():
		t.Error("outer test context done while waiting for generate")
	}
}

// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
	return []api.GenerateRequest{
			{
xuxzh1's avatar
init  
xuxzh1 committed
290
291
292
293
				Model:     "orca-mini",
				Prompt:    "why is the ocean blue?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
mashun1's avatar
v1  
mashun1 committed
294
295
296
297
298
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			}, {
xuxzh1's avatar
init  
xuxzh1 committed
299
300
301
302
				Model:     "orca-mini",
				Prompt:    "why is the color of dirt brown?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
mashun1's avatar
v1  
mashun1 committed
303
304
305
306
307
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			}, {
xuxzh1's avatar
init  
xuxzh1 committed
308
309
310
311
				Model:     "orca-mini",
				Prompt:    "what is the origin of the us thanksgiving holiday?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
mashun1's avatar
v1  
mashun1 committed
312
313
314
315
316
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			}, {
xuxzh1's avatar
init  
xuxzh1 committed
317
318
319
320
				Model:     "orca-mini",
				Prompt:    "what is the origin of independence day?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
mashun1's avatar
v1  
mashun1 committed
321
322
323
324
325
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			}, {
xuxzh1's avatar
init  
xuxzh1 committed
326
327
328
329
				Model:     "orca-mini",
				Prompt:    "what is the composition of air?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
mashun1's avatar
v1  
mashun1 committed
330
331
332
333
334
335
336
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			},
		},
		[][]string{
xuxzh1's avatar
init  
xuxzh1 committed
337
338
339
340
341
			{"sunlight"},
			{"soil", "organic", "earth", "black", "tan"},
			{"england", "english", "massachusetts", "pilgrims", "british"},
			{"fourth", "july", "declaration", "independence"},
			{"nitrogen", "oxygen", "carbon", "dioxide"},
mashun1's avatar
v1  
mashun1 committed
342
343
		}
}