// samplers_test.go: tests and benchmarks for the sample package's samplers.

package sample

import (
	"math"
	"math/rand/v2"
	"testing"

	"github.com/google/go-cmp/cmp"
)

// TestWeighted exercises the weighted sampler on three inputs:
//   - a single finite logit among -Inf values, which must be the one chosen;
//   - all -Inf logits, which must produce an error;
//   - a seeded draw, which must reproduce a pinned index.
func TestWeighted(t *testing.T) {
	negInf := float32(math.Inf(-1))

	got, err := Weighted(nil).Sample([]float32{negInf, 2, negInf, negInf})
	if err != nil {
		t.Fatal(err)
	}
	if want := int32(1); want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}

	got, err = Weighted(nil).Sample([]float32{negInf, negInf, negInf})
	if err == nil {
		t.Error("expected error for no valid tokens, got index", got)
	}

	// A fixed seed makes the draw reproducible. The expected index is pinned
	// to the value observed for seed 42 on these logits; it must be updated if
	// the sampler's RNG or sampling algorithm changes.
	seed := uint64(42)
	got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
	if err != nil {
		t.Fatal(err)
	}
	if want := int32(3); want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}
}

// testTransform is a pass-through Transform used to observe how a sampler
// invokes its transforms. Every Apply call appends the transform's id to the
// shared log (when one is attached) and hands the logits back untouched.
type testTransform struct {
	id        int    // value recorded on each Apply call
	callOrder *[]int // shared invocation log; nil disables recording
}

// Apply logs ts.id into the attached call log, if any, and returns logits
// as-is (same slice, no copy).
func (ts *testTransform) Apply(logits []float64) []float64 {
	order := ts.callOrder
	if order == nil {
		return logits
	}
	*order = append(*order, ts.id)
	return logits
}

// TestSample verifies that Weighted runs its transforms in the order they were
// supplied, each exactly once per Sample call, and that the index it returns
// refers to an element of the input.
func TestSample(t *testing.T) {
	input := []float32{1, 2, 3, 4}

	var callOrder []int
	mock1 := &testTransform{
		id:        1,
		callOrder: &callOrder,
	}
	mock2 := &testTransform{
		id:        2,
		callOrder: &callOrder,
	}
	mock3 := &testTransform{
		id:        3,
		callOrder: &callOrder,
	}

	got, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
	if err != nil {
		t.Fatal(err)
	}
	// The mocks leave every logit finite, so some element must be selected.
	if got < 0 || int(got) >= len(input) {
		t.Errorf("sampled index %d out of range [0, %d)", got, len(input))
	}

	wantOrder := []int{1, 2, 3}
	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
		t.Errorf("call order mismatch (-want +got):\n%s", diff)
	}
}

func TestNewSampler(t *testing.T) {
	tests := []struct {
		name        string
		temperature float32
		topK        int
		topP        float32
		minP        float32
		seed        int
		wantErr     bool
	}{
		{
91
92
93
			name: "no transforms",
			// temperature is 0, so greedy should be used
			wantErr: false,
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
		},
		{
			name:        "temperature",
			temperature: 0.5,
			wantErr:     false,
		},
		{
			name:        "invalid temperature negative",
			temperature: -1,
			wantErr:     true,
		},
		{
			name:        "invalid temperature too high",
			temperature: 2.1,
			wantErr:     true,
		},
		{
111
112
113
114
			name:        "top k",
			topK:        10,
			temperature: 0.8,
			wantErr:     false,
115
116
		},
		{
117
118
119
120
			name:        "invalid top k negative",
			topK:        -1,
			temperature: 0.8,
			wantErr:     true,
121
122
		},
		{
123
124
125
126
			name:        "top p",
			topP:        0.9,
			temperature: 0.8,
			wantErr:     false,
127
128
		},
		{
129
130
131
132
			name:        "invalid top p negative",
			topP:        -0.1,
			temperature: 0.8,
			wantErr:     true,
133
134
		},
		{
135
136
137
138
			name:        "invalid top p one",
			topP:        1.0,
			temperature: 0.8,
			wantErr:     true,
139
140
		},
		{
141
142
143
144
			name:        "min p",
			minP:        0.2,
			temperature: 0.8,
			wantErr:     false,
145
146
		},
		{
147
148
149
150
			name:        "invalid min p negative",
			minP:        -0.1,
			temperature: 0.8,
			wantErr:     true,
151
152
		},
		{
153
154
155
156
			name:        "invalid min p one",
			minP:        1.0,
			temperature: 0.8,
			wantErr:     true,
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
		},
		{
			name:        "default values",
			temperature: 0.8,
			topK:        40,
			topP:        0.9,
			minP:        0.0,
			seed:        0,
			wantErr:     false,
		},
		{
			name:        "all zeroes",
			temperature: 0.0,
			topK:        0,
			topP:        0.0,
			minP:        0.0,
			seed:        0,
174
			wantErr:     false, // all zeroes means no transforms
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
		},
		{
			name:        "all transforms",
			temperature: 0.8,
			topK:        50,
			topP:        0.95,
			minP:        0.1,
			seed:        42,
			wantErr:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func BenchmarkSample(b *testing.B) {
	transforms := []Transform{
		Temperature(0.5),
		TopK(10),
		TopP(0.9),
		MinP(0.2),
	}

	samplers := map[string]Sampler{
206
		"Greedy":   Greedy(),
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
		"Weighted": Weighted(nil, transforms...),
	}

	logits := make([]float32, 1<<16)
	for i := range logits {
		logits[i] = rand.Float32()
	}

	for name, s := range samplers {
		b.Run(name, func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				if _, err := s.Sample(logits); err != nil {
					b.Error(err)
				}
			}
		})
	}
}