samplers_benchmark_test.go 2.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package sample

import (
	"fmt"
	"math/rand"
	"testing"
)

// BenchmarkWeightedSampler measures repeated Sample calls in three groups of
// sub-benchmarks:
//
//   - "Size N": one fixed sampler configuration over logit slices of
//     increasing length.
//   - "Config<Name>": one sampler per entry of the configs table, all run
//     against the same 128000-element logit slice.
//   - "TransformCombined": temperature, top-k, top-p and min-p arguments all
//     set at once.
//
// Input logits come from the package-level math/rand functions and lie in
// [-5, 5).
func BenchmarkWeightedSampler(b *testing.B) {
	// randomLogits returns n logits drawn uniformly from [-5, 5).
	randomLogits := func(n int) []float32 {
		logits := make([]float32, n)
		for i := range logits {
			logits[i] = float32(rand.Float64()*10 - 5)
		}
		return logits
	}

	for _, size := range []int{10, 100, 1000, 10000} {
		b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
			logits := randomLogits(size)
			sampler := NewSampler(0.8, 0, 0, 0, 42)

			// b.Loop resets the timer on its first call, so the setup above
			// is not measured and no explicit b.ResetTimer is needed.
			for b.Loop() {
				if _, err := sampler.Sample(logits); err != nil {
					b.Fatalf("Sampling failed: %v", err)
				}
			}
		})
	}

	// Each entry is passed positionally to NewSampler in field order.
	// NOTE(review): -1 for topK and seed presumably means "disabled/unset";
	// confirm against NewSampler.
	configs := []struct {
		name        string
		temperature float32
		topK        int
		topP        float32
		minP        float32
		seed        int
	}{
		{"Greedy", 0, -1, 0, 0, -1},
		{"Temperature", 0.8, -1, 0, 0, -1},
		{"TopK", 0.8, 50, 0, 0, -1},
		{"TopP", 0.8, -1, 0.9, 0, -1},
		{"MinP", 0.8, -1, 0, 0.05, -1},
		{"WithSeed", 0.8, 50, 0, 0, 42},
	}

	// One fixed, vocabulary-sized input shared by every configuration below.
	logits := randomLogits(128000)

	for _, tc := range configs {
		b.Run("Config"+tc.name, func(b *testing.B) {
			sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)

			// Warm-up call outside the measured loop. Fail fast rather than
			// silently discarding an error from the setup path.
			if _, err := sampler.Sample(logits); err != nil {
				b.Fatalf("Warm-up sampling failed: %v", err)
			}

			for b.Loop() {
				if _, err := sampler.Sample(logits); err != nil {
					b.Fatalf("Sampling failed: %v", err)
				}
			}
		})
	}

	// Combined transforms are benchmarked separately because topK greatly
	// influences performance.
	b.Run("TransformCombined", func(b *testing.B) {
		sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)

		for b.Loop() {
			if _, err := sampler.Sample(logits); err != nil {
				b.Fatalf("Sampling failed: %v", err)
			}
		}
	})
}

// BenchmarkGreedySampler times repeated Sample calls on a sampler built with
// NewSampler(0, -1, 0, 0, -1), the same arguments as the "Greedy" entry in
// BenchmarkWeightedSampler. It runs one "Size N" sub-benchmark per input
// length, with logits drawn from math/rand in the range [-5, 5).
func BenchmarkGreedySampler(b *testing.B) {
	for _, n := range []int{10, 100, 1000, 10000, 100000} {
		name := fmt.Sprintf("Size %d", n)

		b.Run(name, func(b *testing.B) {
			// Build the random input before timing starts.
			input := make([]float32, n)
			for idx := 0; idx < len(input); idx++ {
				input[idx] = float32(rand.Float64()*10 - 5)
			}

			greedy := NewSampler(0, -1, 0, 0, -1)

			// Exclude the setup above from the measurement.
			b.ResetTimer()

			for b.Loop() {
				if _, err := greedy.Sample(input); err != nil {
					b.Fatalf("Sampling failed: %v", err)
				}
			}
		})
	}
}