// max_queue_test.go
//go:build integration

package integration

import (
	"context"
	"errors"
	"log/slog"
	"os"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

Michael Yang's avatar
int  
Michael Yang committed
16
	"github.com/ollama/ollama/api"
17
18
19
)

func TestMaxQueue(t *testing.T) {
20
21
	t.Skip("this test needs to be re-evaluated to use a proper embedding model")

Daniel Hiltgen's avatar
Daniel Hiltgen committed
22
	if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
23
		t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
Daniel Hiltgen's avatar
Daniel Hiltgen committed
24
25
26
		return
	}

27
28
	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
29
30
	threadCount := 16
	t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
31
32

	req := api.GenerateRequest{
33
		Model:  smol,
34
		Prompt: "write a long historical fiction story about christopher columbus.  use at least 10 facts from his actual journey",
35
		Options: map[string]any{
36
37
38
39
40
41
42
43
44
45
46
47
			"seed":        42,
			"temperature": 0.0,
		},
	}
	resp := []string{"explore", "discover", "ocean"}

	// CPU mode takes much longer at the limit with a large queue setting
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

48
49
50
	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatal(err)
	}
51
52
53
54
55
56

	// Context for the worker threads so we can shut them down
	// embedCtx, embedCancel := context.WithCancel(ctx)
	embedCtx := ctx

	var genwg sync.WaitGroup
57
	genwg.Add(1)
58
59
60
61
62
63
64
65
	go func() {
		defer genwg.Done()
		slog.Info("Starting generate request")
		DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
		slog.Info("generate completed")
	}()

	// Give the generate a chance to get started before we start hammering on embed requests
66
	time.Sleep(10 * time.Millisecond)
67
68
69
70
71

	threadCount += 10 // Add a few extra to ensure we push the queue past its limit
	busyCount := 0
	resetByPeerCount := 0
	canceledCount := 0
72
	successCount := 0
73
74
75
	counterMu := sync.Mutex{}
	var embedwg sync.WaitGroup
	for i := 0; i < threadCount; i++ {
76
		embedwg.Add(1)
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
		go func(i int) {
			defer embedwg.Done()
			slog.Info("embed started", "id", i)
			embedReq := api.EmbeddingRequest{
				Model:   req.Model,
				Prompt:  req.Prompt,
				Options: req.Options,
			}
			// Fresh client for every request
			client, _ = GetTestEndpoint()

			resp, genErr := client.Embeddings(embedCtx, &embedReq)
			counterMu.Lock()
			defer counterMu.Unlock()
			switch {
			case genErr == nil:
93
				successCount++
94
95
96
				if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
					t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
				}
97
98
99
100
101
102
103
			case errors.Is(genErr, context.Canceled):
				canceledCount++
			case strings.Contains(genErr.Error(), "busy"):
				busyCount++
			case strings.Contains(genErr.Error(), "connection reset by peer"):
				resetByPeerCount++
			default:
104
105
106
				if genErr != nil {
					t.Fatalf("%d request failed", i)
				}
107
108
109
110
111
112
113
114
115
			}

			slog.Info("embed finished", "id", i)
		}(i)
	}
	genwg.Wait()
	slog.Info("generate done, waiting for embeds")
	embedwg.Wait()

116
	slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
117
118
119
120
121
122
123
124
125
	if resetByPeerCount != 0 {
		t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
	}
	if busyCount == 0 {
		t.Fatalf("no requests hit busy error but some should have")
	}
	if canceledCount > 0 {
		t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
	}
126
}