package sample

import (
	"math/rand/v2"
	"testing"
)

func TestWeighted(t *testing.T) {
9
10
11
	logits := []float32{-10, 3, -10, -10}
	sampler := NewSampler(0, 0, 0, 0, 0)
	got, err := sampler.Sample(logits)
12
13
14
15
16
17
18
19
20
	if err != nil {
		t.Error(err)
		return
	}
	want := int32(1)
	if want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}

21
22
23
	logits = []float32{-100, -10, 0, 10}
	sampler = NewSampler(0, 0, 0, 0, 0)
	got, err = sampler.Sample(logits)
24
25
26
27
	if err != nil {
		t.Error(err)
		return
	}
28
	want = int32(3) // Should pick highest probability with this r value
29
30
31
32
33
34
35
36
37
38
39
40
41
	if want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}
}

func TestNewSampler(t *testing.T) {
	tests := []struct {
		name        string
		temperature float32
		topK        int
		topP        float32
		minP        float32
		seed        int
42
		wantGreedy  bool // Instead of wantErr, check if we get greedy sampler
43
44
45
46
	}{
		{
			name:        "temperature",
			temperature: 0.5,
47
			wantGreedy:  false,
48
49
		},
		{
50
51
52
			name:        "zero temperature - greedy",
			temperature: 0,
			wantGreedy:  true,
53
54
		},
		{
55
			name:        "top k",
56
			temperature: 0.1,
57
			topK:        10,
58
			wantGreedy:  false,
59
60
		},
		{
61
			name:        "top p",
62
			temperature: 0.1,
63
			topP:        0.9,
64
			wantGreedy:  false,
65
66
		},
		{
67
			name:        "min p",
68
			temperature: 0.1,
69
			minP:        0.2,
70
			wantGreedy:  false,
71
72
		},
		{
73
74
75
76
			name:        "seed - weighted",
			temperature: 0.1,
			seed:        42,
			wantGreedy:  false,
77
78
79
80
81
82
83
84
		},
		{
			name:        "default values",
			temperature: 0.8,
			topK:        40,
			topP:        0.9,
			minP:        0.0,
			seed:        0,
85
			wantGreedy:  false,
86
87
		},
		{
88
			name:        "all zeroes - greedy",
89
90
91
92
93
			temperature: 0.0,
			topK:        0,
			topP:        0.0,
			minP:        0.0,
			seed:        0,
94
			wantGreedy:  true,
95
96
97
98
99
100
101
102
		},
		{
			name:        "all transforms",
			temperature: 0.8,
			topK:        50,
			topP:        0.95,
			minP:        0.1,
			seed:        42,
103
			wantGreedy:  false,
104
105
106
107
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
108
109
110
111
			sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
			_, isGreedy := sampler.(*greedy)
			if isGreedy != tt.wantGreedy {
				t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
112
113
114
115
116
117
			}
		})
	}
}

func BenchmarkSample(b *testing.B) {
118
	weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
119
	samplers := map[string]Sampler{
120
121
		"Greedy":   NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
		"Weighted": weighted,
122
123
	}

124
	// Generate random logits for benchmarking
125
126
127
128
129
130
131
132
	logits := make([]float32, 1<<16)
	for i := range logits {
		logits[i] = rand.Float32()
	}

	for name, s := range samplers {
		b.Run(name, func(b *testing.B) {
			b.ResetTimer()
133
			for b.Loop() {
134
135
136
137
138
139
140
				if _, err := s.Sample(logits); err != nil {
					b.Error(err)
				}
			}
		})
	}
}