concurrency_test.go 6.23 KB
Newer Older
wangkx1's avatar
wangkx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
//go:build integration

package integration

import (
	"context"
	"log/slog"
	"os"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
)

// TestMultiModelConcurrency issues one generate request against each of two
// different models at the same time and checks each reply via DoGenerate.
// Both models are pulled up front so download time is not part of the
// concurrent phase.
func TestMultiModelConcurrency(t *testing.T) {
	// Fixed seed and zero temperature keep the generated text reproducible
	// enough to match against the expected keywords below.
	requests := [2]api.GenerateRequest{
		{
			Model:     "orca-mini",
			Prompt:    "why is the ocean blue?",
			Stream:    &stream,
			KeepAlive: &api.Duration{Duration: 10 * time.Second},
			Options: map[string]any{
				"seed":        42,
				"temperature": 0.0,
			},
		},
		{
			Model:     "tinydolphin",
			Prompt:    "what is the origin of the us thanksgiving holiday?",
			Stream:    &stream,
			KeepAlive: &api.Duration{Duration: 10 * time.Second},
			Options: map[string]any{
				"seed":        42,
				"temperature": 0.0,
			},
		},
	}
	// expected[i] is the keyword list handed to DoGenerate for requests[i].
	expected := [2][]string{
		{"sunlight"},
		{"england", "english", "massachusetts", "pilgrims", "british"},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 240*time.Second)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Make sure every model is available locally before the concurrent phase.
	for idx := range requests {
		require.NoError(t, PullIfMissing(ctx, client, requests[idx].Model))
	}

	var pending sync.WaitGroup
	for idx := range requests {
		pending.Add(1)
		go func(n int) {
			defer pending.Done()
			DoGenerate(ctx, t, client, requests[n], expected[n], 60*time.Second, 10*time.Second)
		}(idx)
	}
	pending.Wait()
}

// TestIntegrationConcurrentPredictOrcaMini warms the server with a single
// request and then runs the requests returned by GenerateRequests
// concurrently, each repeated several times. When OLLAMA_MAX_VRAM reports
// less than 4 GiB, both the number of concurrent requests and the number of
// repetitions are reduced.
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	requests, expected := GenerateRequests()
	workers := len(requests)
	rounds := 5

	if raw := os.Getenv("OLLAMA_MAX_VRAM"); raw != "" {
		vram, err := strconv.ParseUint(raw, 10, 64)
		require.NoError(t, err)
		// Go easy on cards with little VRAM.
		if vram < 4*format.GibiByte {
			if workers > 2 {
				workers = 2
			}
			rounds = 2
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// A single up-front request starts the server (if applicable) and warms the model.
	DoGenerate(ctx, t, client, requests[0], expected[0], 60*time.Second, 10*time.Second)

	var pending sync.WaitGroup
	for w := 0; w < workers; w++ {
		pending.Add(1)
		go func(w int) {
			defer pending.Done()
			for round := 0; round < rounds; round++ {
				slog.Info("Starting", "req", w, "iter", round)
				// Concurrent requests can be slow to process on slower GPUs,
				// hence the much longer initial timeout than the warm-up call.
				DoGenerate(ctx, t, client, requests[w], expected[w], 120*time.Second, 20*time.Second)
			}
		}(w)
	}
	pending.Wait()
}

// TestMultiModelStress stresses the system when its VRAM size is known by
// attempting to load more models than will fit.
//
// The VRAM size is read from OLLAMA_MAX_VRAM (bytes, base 10); the test is
// skipped when the variable is unset. A model set is chosen from that size,
// one model is assigned to each request from GenerateRequests, and requests
// are launched concurrently until the approximate model sizes exceed the
// reported VRAM (always launching at least two). While the requests run, a
// background goroutine logs the running models every two seconds.
func TestMultiModelStress(t *testing.T) {
	s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
	if s == "" {
		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
	}

	maxVram, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		t.Fatal(err)
	}

	type model struct {
		name string
		size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
	}

	smallModels := []model{
		{
			name: "orca-mini",
			size: 2992 * format.MebiByte,
		},
		{
			name: "phi",
			size: 2616 * format.MebiByte,
		},
		{
			name: "gemma:2b",
			size: 2364 * format.MebiByte,
		},
		{
			name: "stable-code:3b",
			size: 2608 * format.MebiByte,
		},
		{
			name: "starcoder2:3b",
			size: 2166 * format.MebiByte,
		},
	}
	mediumModels := []model{
		{
			name: "llama2",
			size: 5118 * format.MebiByte,
		},
		{
			name: "mistral",
			size: 4620 * format.MebiByte,
		},
		{
			name: "orca-mini:7b",
			size: 5118 * format.MebiByte,
		},
		{
			name: "dolphin-mistral",
			size: 4620 * format.MebiByte,
		},
		{
			name: "gemma:7b",
			size: 5000 * format.MebiByte,
		},
		{
			name: "codellama:7b",
			size: 5118 * format.MebiByte,
		},
	}

	// These seem to be too slow to be useful...
	// largeModels := []model{
	// 	{
	// 		name: "llama2:13b",
	// 		size: 7400 * format.MebiByte,
	// 	},
	// 	{
	// 		name: "codellama:13b",
	// 		size: 7400 * format.MebiByte,
	// 	},
	// 	{
	// 		name: "orca-mini:13b",
	// 		size: 7400 * format.MebiByte,
	// 	},
	// 	{
	// 		name: "gemma:7b",
	// 		size: 5000 * format.MebiByte,
	// 	},
	// 	{
	// 		name: "starcoder2:15b",
	// 		size: 9100 * format.MebiByte,
	// 	},
	// }

	var chosenModels []model
	switch {
	case maxVram < 10000*format.MebiByte:
		slog.Info("selecting small models")
		chosenModels = smallModels
	// case maxVram < 30000*format.MebiByte:
	default:
		slog.Info("selecting medium models")
		chosenModels = mediumModels
		// default:
		// 	slog.Info("selecting large models")
		// 	chosenModels = largeModels
	}

	req, resp := GenerateRequests()

	// Only requests that have a model assigned take part in the test. Bounding
	// every loop by this limit keeps chosenModels[i] in range even if
	// GenerateRequests returns more requests than there are chosen models.
	limit := min(len(req), len(chosenModels))
	for i := 0; i < limit; i++ {
		req[i].Model = chosenModels[i].name
	}

	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Make sure all the models are pulled before we get started
	for i := 0; i < limit; i++ {
		require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
	}

	var wg sync.WaitGroup
	consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
	for i := 0; i < limit; i++ {
		// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
		if i > 1 && consumed > maxVram {
			slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
			break
		}
		consumed += chosenModels[i].size
		slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))

		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			for j := 0; j < 3; j++ {
				slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
				DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
			}
		}(i)
	}

	// Periodically log which models are running. The monitor gets its own
	// cancellable context and is waited for below, so it has stopped before
	// the deferred cleanup runs.
	monitorCtx, stopMonitor := context.WithCancel(ctx)
	defer stopMonitor()
	monitorDone := make(chan struct{})
	go func() {
		defer close(monitorDone)
		ticker := time.NewTicker(2 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-monitorCtx.Done():
				return
			case <-ticker.C:
			}
			models, err := client.ListRunning(monitorCtx)
			if err != nil {
				// A failure caused by our own shutdown is expected; stay quiet.
				if monitorCtx.Err() == nil {
					slog.Warn("failed to list running models", "error", err)
				}
				continue
			}
			for _, m := range models.Models {
				slog.Info("loaded model snapshot", "model", m)
			}
		}
	}()

	wg.Wait()
	stopMonitor()
	<-monitorDone
}