//go:build integration

package integration

import (
	"context"
	"errors"
	"log/slog"
	"os"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
)

func TestMaxQueue(t *testing.T) {
	if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
		t.Skip("Max Queue test requires spawing a local server so we can adjust the queue size")
		return
	}

	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
	threadCount := 32
xuxzh1's avatar
init  
xuxzh1 committed
31
32
	if maxQueue := envconfig.MaxQueue(); maxQueue != 0 {
		threadCount = int(maxQueue)
mashun1's avatar
v1  
mashun1 committed
33
	} else {
xuxzh1's avatar
init  
xuxzh1 committed
34
		t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
mashun1's avatar
v1  
mashun1 committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
	}

	req := api.GenerateRequest{
		Model:  "orca-mini",
		Prompt: "write a long historical fiction story about christopher columbus.  use at least 10 facts from his actual journey",
		Options: map[string]interface{}{
			"seed":        42,
			"temperature": 0.0,
		},
	}
	resp := []string{"explore", "discover", "ocean"}

	// CPU mode takes much longer at the limit with a large queue setting
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	require.NoError(t, PullIfMissing(ctx, client, req.Model))

	// Context for the worker threads so we can shut them down
	// embedCtx, embedCancel := context.WithCancel(ctx)
	embedCtx := ctx

	var genwg sync.WaitGroup
	go func() {
		genwg.Add(1)
		defer genwg.Done()
		slog.Info("Starting generate request")
		DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
		slog.Info("generate completed")
	}()

	// Give the generate a chance to get started before we start hammering on embed requests
	time.Sleep(5 * time.Millisecond)

	threadCount += 10 // Add a few extra to ensure we push the queue past its limit
	busyCount := 0
	resetByPeerCount := 0
	canceledCount := 0
	succesCount := 0
	counterMu := sync.Mutex{}
	var embedwg sync.WaitGroup
	for i := 0; i < threadCount; i++ {
		go func(i int) {
			embedwg.Add(1)
			defer embedwg.Done()
			slog.Info("embed started", "id", i)
			embedReq := api.EmbeddingRequest{
				Model:   req.Model,
				Prompt:  req.Prompt,
				Options: req.Options,
			}
			// Fresh client for every request
			client, _ = GetTestEndpoint()

			resp, genErr := client.Embeddings(embedCtx, &embedReq)
			counterMu.Lock()
			defer counterMu.Unlock()
			switch {
			case genErr == nil:
				succesCount++
				require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
			case errors.Is(genErr, context.Canceled):
				canceledCount++
			case strings.Contains(genErr.Error(), "busy"):
				busyCount++
			case strings.Contains(genErr.Error(), "connection reset by peer"):
				resetByPeerCount++
			default:
				require.NoError(t, genErr, "%d request failed", i)
			}

			slog.Info("embed finished", "id", i)
		}(i)
	}
	genwg.Wait()
	slog.Info("generate done, waiting for embeds")
	embedwg.Wait()

	slog.Info("embeds completed", "success", succesCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
	require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
	require.True(t, busyCount > 0, "no requests hit busy error but some should have")
	require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")

}