"server/vscode:/vscode.git/clone" did not exist on "1e7f62cb429e5a962dd9c448e7b1b3371879e48b"
server_benchmark_test.go 4.56 KB
Newer Older
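// Package benchmark measures end-to-end performance of a running Ollama
// server: time to first token, model load time, and prompt/generation
// throughput, under both cold-start and warm-start conditions.
//
// The model under test is selected with the -m flag defined below; for
// example (model name and benchtime are illustrative):
//
//	go test -bench=. -benchtime=1x -m llama3.2 .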
package benchmark

import (
	"context"
	"flag"
	"fmt"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// Command line flags
var modelFlag string

func init() {
	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
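	// Cosmetic only: show "model" as the default in -help output instead of
	// an empty string; the flag still defaults to "" and must be set.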
	flag.Lookup("m").DefValue = "model"
}

// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
	if modelFlag == "" {
		b.Fatal("Error: -m flag is required for benchmark tests")
	}
	return modelFlag
}

type TestCase struct {
	name      string
	prompt    string
	maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
	start := time.Now()
	var ttft time.Duration // time to first token
	var metrics api.Metrics

	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		if ttft == 0 && resp.Response != "" {
			ttft = time.Since(start)
		}
		if resp.Done {
			metrics = resp.Metrics
		}
		return nil
	})
	if err != nil {
		b.Fatal(err)
	}

	// Report custom metrics as part of the benchmark results
	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

	// Token throughput metrics
	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
	b.ReportMetric(promptThroughput, "prompt_tok/s")
	b.ReportMetric(genThroughput, "gen_tok/s")

	// Token counts
	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
			ctx := b.Context()

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				b.StopTimer()
				// Ensure model is unloaded before each iteration, keeping the
				// unload itself out of the timed region.
				unload(client, m, b)
				b.StartTimer()

				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
			ctx := b.Context()

			// Pre-warm the model
			warmup(client, m, tt.prompt, b)

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		b.Fatal(err)
	}
	if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
		b.Fatalf("Model unavailable: %v", err)
	}

	return client
}

// warmup ensures the model is loaded and warmed up by running a few short generations before measurement
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
	for range 3 {
		err := client.Generate(
			context.Background(),
			&api.GenerateRequest{
				Model:   model,
				Prompt:  prompt,
				Options: map[string]any{"num_predict": 50, "temperature": 0.1},
			},
			func(api.GenerateResponse) error { return nil },
		)
		if err != nil {
			b.Logf("Error during model warm-up: %v", err)
		}
	}
}

// unload forces the model to be unloaded by sending a generate request with KeepAlive set to 0
func unload(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	time.Sleep(1 * time.Second) // give the server a moment to finish unloading
}