transforms_test.go 8.2 KB
Newer Older
1
2
3
4
5
6
7
8
package sample

import (
	"math"
	"math/rand/v2"
	"testing"
)

9
10
// Helper to convert float32 slice to logit slice
func toTokens(values []float32) []token {
11
	tokens := make([]token, len(values))
12
	for i, v := range values {
13
		tokens[i] = token{
14
			id:    int32(i),
15
			value: v,
16
17
18
19
20
21
		}
	}
	return tokens
}

// Helper to compare logit slices
22
func compareLogits(t *testing.T, name string, want []float32, got []token) {
23
24
25
26
27
28
	t.Helper()
	if len(want) != len(got) {
		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
		return
	}
	for i := range want {
29
		if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
30
31
			t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
		}
32
33
34
	}
}

35
36
func TestTemperature(t *testing.T) {
	input := []float32{1.0, 4.0, -2.0, 0.0}
37
38
	tokens := toTokens(input)
	temperature(tokens, 0.5)
39
	want := []float32{2.0, 8.0, -4.0, 0.0}
40
	compareLogits(t, "temperature(0.5)", want, tokens)
41

42
43
44
	input = []float32{1.0, 4.0, -2.0, 0.0}
	tokens = toTokens(input)
	temperature(tokens, 1.0)
45
	want = []float32{1.0, 4.0, -2.0, 0.0}
46
	compareLogits(t, "temperature(1)", want, tokens)
47

48
49
50
	input = []float32{1.0, 4.0, -2.0, 0.0}
	tokens = toTokens(input)
	temperature(tokens, 0.0)
51
	want = []float32{1e7, 4e7, -2e7, 0.0}
52
	compareLogits(t, "temperature(0)", want, tokens)
53
}
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
func TestSoftmax(t *testing.T) {
	tests := []struct {
		name     string
		input    []float32
		expected []float32
	}{
		{
			name:     "correctness softmax",
			input:    []float32{1, -2, 3, 0},
			expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
		},
		{
			name:  "normal distribution",
			input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
		},
		{
			name:  "single value",
			input: []float32{1.0},
		},
		{
			name:  "identical values",
			input: []float32{0.9, 0.9, 0.9},
		},
		{
			name:  "large values",
			input: []float32{1000.0, 2000.0, 3000.0},
		},
		{
			name:  "small values",
			input: []float32{1e-6, 2e-6, 3e-6},
		},
		{
			name:  "negative values",
			input: []float32{-1.0, -2.0, -3.0},
		},
		{
			name:  "mixed values",
			input: []float32{-100.0, 0.0, 100.0},
		},
94
	}
95
96
97

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
98
99
			tokens := toTokens(tt.input)
			softmax(tokens)
100
101

			if tt.expected != nil {
102
				compareLogits(t, tt.name, tt.expected, tokens)
103
104
105
106
107
				return
			}

			// Check probabilities sum to 1
			var sum float32
108
			for _, token := range tokens {
109
110
111
112
113
114
115
116
117
				sum += token.value
				if token.value < 0 || token.value > 1 {
					t.Errorf("probability out of range [0,1]: got %f", token.value)
				}
			}
			if math.Abs(float64(sum-1.0)) > 1e-6 {
				t.Errorf("probabilities don't sum to 1: got %f", sum)
			}
		})
118
	}
119
}
120

121
func TestTopK(t *testing.T) {
122
	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
123
124
125
126
	tokens := toTokens(input)
	tokens = topK(tokens, 5)
	if len(tokens) != 5 {
		t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
127
	}
128
	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
129
	compareLogits(t, "topK(3)", want, tokens)
130

131
132
133
134
	tokens = toTokens(input)
	tokens = topK(tokens, 20)
	if len(tokens) != len(input) {
		t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
135
	}
ParthSareen's avatar
ParthSareen committed
136
137
138

	input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
	want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
139
140
141
142
	tokens = toTokens(input)
	tokens = topK(tokens, -1)
	if len(tokens) != len(input) {
		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
ParthSareen's avatar
ParthSareen committed
143
	}
144
	compareLogits(t, "topK(-1)", want, tokens)
ParthSareen's avatar
ParthSareen committed
145
146
147

	input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
	want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
148
149
150
151
152
153
154
155
156
157
158
159
	tokens = toTokens(input)
	tokens = topK(tokens, 0)
	if len(tokens) != len(input) {
		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
	}
	compareLogits(t, "topK(-1)", want, tokens)

	input = []float32{-1e7, -2e7, -3e7, -4e7}
	tokens = toTokens(input)
	tokens = topK(tokens, 1)
	if len(tokens) < 1 {
		t.Error("topK should keep at least one token")
ParthSareen's avatar
ParthSareen committed
160
	}
161
162
163
}

func TestTopP(t *testing.T) {
164
	input := []float32{-3, -2, -1, 0, 1, 2, 4}
165
	tokens := toTokens(input)
166
167

	// First apply temperature and softmax to get probabilities
168
	softmax(tokens)
ParthSareen's avatar
ParthSareen committed
169
	tokens = topK(tokens, 20)
170
171

	// Then apply topP
172
	tokens = topP(tokens, 0.95)
173
174

	// Should keep tokens until cumsum > 0.95
175
176
177
178
179
180
181
182
183
184
185
186
	if len(tokens) > 3 {
		t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
		t.Logf("got: %v", tokens)
	}

	// Test edge case - ensure at least one token remains
	input = []float32{-1e6, -1e6, -1e6} // One dominant token
	tokens = toTokens(input)
	softmax(tokens)
	tokens = topP(tokens, 0.0) // Very small p
	if len(tokens) < 1 {
		t.Error("topP should keep at least one token")
187
188
189
190
	}
}

func TestMinP(t *testing.T) {
191
	input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
192
	tokens := toTokens(input)
193
194

	// First apply temperature and softmax
195
196
	tokens = topK(tokens, 20)
	softmax(tokens)
197

198
199
200
201
202
203
204
205
206
207
208
	tokens = minP(tokens, 1.0)

	if len(tokens) != 1 {
		t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
	}

	// Test with normal p value
	tokens = toTokens(input) // Reset tokens
	tokens = topK(tokens, 20)
	softmax(tokens)
	tokens = minP(tokens, 0.2)
209
210

	// Should keep tokens with prob >= 0.2 * max_prob
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
	if len(tokens) > 3 {
		t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
		t.Logf("got: %v", tokens)
	}

	// Test with zero p value
	tokens = toTokens(input) // Reset tokens
	tokens = topK(tokens, 20)
	softmax(tokens)
	tokens = minP(tokens, 0.0)

	// Should keep only the highest probability token
	if len(tokens) != len(input) {
		t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
		t.Logf("got: %v", tokens)
	}

	input = []float32{1e-10, 1e-10, 1e-10}
	tokens = toTokens(input)
	softmax(tokens)
	tokens = minP(tokens, 1.0)
	if len(tokens) < 1 {
		t.Error("minP should keep at least one token even with extreme probabilities")
234
235
236
	}
}

237
func TestSortLogits(t *testing.T) {
238
	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
239
	tokens := toTokens(input)
240

ParthSareen's avatar
ParthSareen committed
241
	tokens = topK(tokens, 20)
242

243
244
245
246
247
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > tokens[i-1].value {
			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
				i, tokens[i].value, tokens[i-1].value)
		}
248
249
	}

250
	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
251
252
253
254
255
	compareLogits(t, "sortLogits", want, tokens)
}

func BenchmarkTransforms(b *testing.B) {
	// Generate random logits
256
	tokens := make([]token, 1<<16)
257
	for i := range tokens {
258
		tokens[i] = token{
259
260
261
			id:    int32(i),
			value: rand.Float32(),
		}
262
	}
263

264
	tokensCopy := make([]token, len(tokens))
265
266
267
268
269
270
271
272
273

	b.Run("Temperature", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			temperature(tokensCopy, 0.5)
		}
	})

274
275
276
277
278
279
280
281
	b.Run("Softmax", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			softmax(tokensCopy)
		}
	})

282
283
284
285
	b.Run("TopK", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
286
			tokens = topK(tokensCopy, 10)
287
288
289
290
291
292
293
		}
	})

	b.Run("TopP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
294
			tokens = topP(tokensCopy, 0.9)
295
296
297
298
299
300
301
		}
	})

	b.Run("MinP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
302
			tokens = minP(tokensCopy, 0.2)
303
304
305
306
307
308
309
		}
	})

	b.Run("SortTokens", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
310
			tokens = topK(tokensCopy, 200000)
311
312
		}
	})
313
}