//go:build integration && perf

package integration

import (
	"context"
	"fmt"
	"log/slog"
	"math"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
)

var (
	// Models that don't work reliably with the large context prompt in this test case
	longContextFlakes = []string{
		"granite-code:latest",
		"nemotron-mini:latest",
		"falcon:latest",  // 2k model
		"falcon2:latest", // 2k model
		"minicpm-v:latest",
		"qwen:latest",
		"solar-pro:latest",
	}
)

// Note: this test case can take a long time to run, particularly on models with
// large contexts. Run with -timeout set to a large value to get reasonable coverage.
// Example usage:
//
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
// cat int.log | grep MODEL_PERF_HEADER | head -1 | cut -f2- -d: > perf.csv
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) {
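	// When OLLAMA_NEW_ENGINE is set, limit the run to models served by the Ollama engine;
	// otherwise also include models handled by the llama runner.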
	if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
		doModelPerfTest(t, ollamaEngineChatModels)
	} else {
		doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...))
	}
}

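// TestLibraryModelsPerf runs the same performance sweep across the broader library chat model set.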
func TestLibraryModelsPerf(t *testing.T) {
	doModelPerfTest(t, libraryChatModels)
}

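// doModelPerfTest pulls each model if needed, loads it at a range of context sizes, runs a short
// prompt and a long summarization prompt, and emits MODEL_PERF_HEADER/MODEL_PERF_DATA lines on
// stderr so results can be extracted into a CSV (see the example above TestModelsPerf).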
func doModelPerfTest(t *testing.T, chatModels []string) {
	softTimeout, hardTimeout := getTimeouts(t)
	slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
	started := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// TODO use info API eventually
	var maxVram uint64
	var err error
	if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
		maxVram, err = strconv.ParseUint(s, 10, 64)
		if err != nil {
			t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
		}
	} else {
		slog.Warn("No VRAM info available, testing all models, so larger ones might time out...")
	}

	data, err := os.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
	if err != nil {
		t.Fatalf("failed to open test data file: %s", err)
	}
	longPrompt := "summarize the following: " + string(data)

	targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")

	for _, model := range chatModels {
		if !strings.Contains(model, ":") {
			model = model + ":latest"
		}
		t.Run(model, func(t *testing.T) {
			if time.Since(started) > softTimeout {
				t.Skip("skipping remaining tests to avoid excessive runtime")
			}
			if err := PullIfMissing(ctx, client, model); err != nil {
				t.Fatalf("pull failed %s", err)
			}
			var maxContext int

			resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
			if err != nil {
				t.Fatalf("show failed: %s", err)
			}
			arch := resp.ModelInfo["general.architecture"].(string)
			maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
			if targetArch != "" && arch != targetArch {
				t.Skipf("Skipping %s architecture %s != %s", model, arch, targetArch)
			}

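			// Skip models far too large for the available VRAM; models up to roughly a third
			// larger than VRAM are still run so some CPU overflow gets exercised.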
			if maxVram > 0 {
				resp, err := client.List(ctx)
				if err != nil {
					t.Fatalf("list models failed %v", err)
				}
				for _, m := range resp.Models {
					// For these tests we want to exercise some amount of overflow on the CPU
					if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
						t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
					}
				}
			}
			slog.Info("scenario", "model", model, "max_context", maxContext)
			loaded := false
			defer func() {
				// best effort unload once we're done with the model
				if loaded {
					client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
				}
			}()

			// Some models don't handle the long context data well so skip them to avoid flaky test results
			longContextFlake := false
			for _, flake := range longContextFlakes {
				if model == flake {
					longContextFlake = true
					break
				}
			}

			// iterate through a few context sizes for coverage without excessive runtime
			var contexts []int
			keepGoing := true
			if maxContext > 16384 {
				contexts = []int{4096, 8192, 16384, maxContext}
			} else if maxContext > 8192 {
				contexts = []int{4096, 8192, maxContext}
			} else if maxContext > 4096 {
				contexts = []int{4096, maxContext}
			} else if maxContext > 0 {
				contexts = []int{maxContext}
			} else {
				t.Fatal("unknown max context size")
			}
			for _, numCtx := range contexts {
				if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
					break
				}
				skipLongPrompt := false

				// Workaround bug 11172 temporarily...
				maxPrompt := longPrompt
				// If we fill the context too full with the prompt, many models
				// quickly hit context shifting and go bad.
				if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
					maxPrompt = maxPrompt[:numCtx*2]
				}

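				// Each context size exercises two prompts: a short sanity-check prompt and the
				// (possibly truncated) long summarization prompt.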
				testCases := []struct {
					prompt  string
					anyResp []string
				}{
					{blueSkyPrompt, blueSkyExpected},
					{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}},
				}
				var gpuPercent int
				for _, tc := range testCases {
					if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
						slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
						continue
					}
					req := api.GenerateRequest{
						Model:     model,
						Prompt:    tc.prompt,
						KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
						Options: map[string]interface{}{
							"temperature": 0,
							"seed":        123,
							"num_ctx":     numCtx,
						},
					}
					atLeastOne := false
					var resp api.GenerateResponse

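					// Use non-streaming generation so the callback receives a single final
					// response carrying the load/prompt/eval metrics reported below.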
					stream := false
					req.Stream = &stream

					// Avoid potentially getting stuck indefinitely
					limit := 5 * time.Minute
					genCtx, cancel := context.WithDeadlineCause(
						ctx,
						time.Now().Add(limit),
						fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
					)
					defer cancel()

					err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
						resp = rsp
						return nil
					})
					if err != nil {
						// Avoid excessive runtime, but don't treat a timeout at very large context as a failure
						if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
							slog.Warn("max context was taking too long, skipping", "error", err)
							keepGoing = false
							skipLongPrompt = true
							continue
						}
						t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
					}
					loaded = true
					for _, expResp := range tc.anyResp {
						if strings.Contains(strings.ToLower(resp.Response), expResp) {
							atLeastOne = true
							break
						}
					}
					if !atLeastOne {
						t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s", numCtx, tc.anyResp, resp.Response)
					}
					models, err := client.ListRunning(ctx)
					if err != nil {
						slog.Warn("failed to list running models", "error", err)
						continue
					}
					if len(models.Models) > 1 {
						slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
					}
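					// Determine how the model was split between CPU and GPU; heavily CPU-offloaded
					// runs skip the long prompt and stop escalating the context size to bound runtime.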
					for _, m := range models.Models {
						if m.Name == model {
							if m.SizeVRAM == 0 {
								slog.Info("Model fully loaded into CPU")
								gpuPercent = 0
								keepGoing = false
								skipLongPrompt = true
							} else if m.SizeVRAM == m.Size {
								slog.Info("Model fully loaded into GPU")
								gpuPercent = 100
							} else {
								sizeCPU := m.Size - m.SizeVRAM
								cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
								gpuPercent = int(100 - cpuPercent)
								slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
								keepGoing = false

								// Heuristic to avoid excessive test run time
								if gpuPercent < 90 {
									skipLongPrompt = true
								}
							}
						}
					}
					// Round the logged prompt count so results can be compared across versions/configurations, which can vary slightly
					fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
						"MODEL",
						"CONTEXT",
						"GPU PERCENT",
						"APPROX PROMPT COUNT",
						"LOAD TIME",
						"PROMPT EVAL TPS",
						"EVAL TPS",
					)
					fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
						model,
						numCtx,
						gpuPercent,
						(resp.PromptEvalCount/10)*10,
						float64(resp.LoadDuration)/1000000000.0,
						float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
						float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
					)
				}
			}
		})
	}
}