transforms_test.go 9 KB
Newer Older
1
2
3
4
5
6
7
8
package sample

import (
	"math"
	"math/rand/v2"
	"testing"
)

9
10
// Helper to convert float32 slice to logit slice
func toTokens(values []float32) []token {
11
	tokens := make([]token, len(values))
12
	for i, v := range values {
13
		tokens[i] = token{
14
			id:    int32(i),
15
			value: v,
16
17
18
19
20
21
		}
	}
	return tokens
}

// Helper to compare logit slices
22
func compareLogits(t *testing.T, name string, want []float32, got []token) {
23
24
25
26
27
28
	t.Helper()
	if len(want) != len(got) {
		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
		return
	}
	for i := range want {
29
		if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
30
31
			t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
		}
32
33
34
	}
}

35
36
func TestTemperature(t *testing.T) {
	input := []float32{1.0, 4.0, -2.0, 0.0}
37
38
	tokens := toTokens(input)
	temperature(tokens, 0.5)
39
	want := []float32{2.0, 8.0, -4.0, 0.0}
40
	compareLogits(t, "temperature(0.5)", want, tokens)
41

42
43
44
	input = []float32{1.0, 4.0, -2.0, 0.0}
	tokens = toTokens(input)
	temperature(tokens, 1.0)
45
	want = []float32{1.0, 4.0, -2.0, 0.0}
46
	compareLogits(t, "temperature(1)", want, tokens)
47

48
49
50
	input = []float32{1.0, 4.0, -2.0, 0.0}
	tokens = toTokens(input)
	temperature(tokens, 0.0)
51
	want = []float32{1e7, 4e7, -2e7, 0.0}
52
	compareLogits(t, "temperature(0)", want, tokens)
53
}
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
func TestSoftmax(t *testing.T) {
	tests := []struct {
		name     string
		input    []float32
		expected []float32
	}{
		{
			name:     "correctness softmax",
			input:    []float32{1, -2, 3, 0},
			expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
		},
		{
			name:  "normal distribution",
			input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
		},
		{
			name:  "single value",
			input: []float32{1.0},
		},
		{
			name:  "identical values",
			input: []float32{0.9, 0.9, 0.9},
		},
		{
			name:  "large values",
			input: []float32{1000.0, 2000.0, 3000.0},
		},
		{
			name:  "small values",
			input: []float32{1e-6, 2e-6, 3e-6},
		},
		{
			name:  "negative values",
			input: []float32{-1.0, -2.0, -3.0},
		},
		{
			name:  "mixed values",
			input: []float32{-100.0, 0.0, 100.0},
		},
94
	}
95
96
97

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
98
99
			tokens := toTokens(tt.input)
			softmax(tokens)
100
101

			if tt.expected != nil {
102
				compareLogits(t, tt.name, tt.expected, tokens)
103
104
105
106
107
				return
			}

			// Check probabilities sum to 1
			var sum float32
108
			for _, token := range tokens {
109
110
111
112
113
114
115
116
117
				sum += token.value
				if token.value < 0 || token.value > 1 {
					t.Errorf("probability out of range [0,1]: got %f", token.value)
				}
			}
			if math.Abs(float64(sum-1.0)) > 1e-6 {
				t.Errorf("probabilities don't sum to 1: got %f", sum)
			}
		})
118
	}
119
}
120

121
func TestTopK(t *testing.T) {
122
	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
123
124
125
126
	tokens := toTokens(input)
	tokens = topK(tokens, 5)
	if len(tokens) != 5 {
		t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
127
	}
128
	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
129
	compareLogits(t, "topK(3)", want, tokens)
130

131
132
133
134
	tokens = toTokens(input)
	tokens = topK(tokens, 20)
	if len(tokens) != len(input) {
		t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
135
	}
ParthSareen's avatar
ParthSareen committed
136
137
138

	input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
	want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
139
140
141
142
	tokens = toTokens(input)
	tokens = topK(tokens, -1)
	if len(tokens) != len(input) {
		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
ParthSareen's avatar
ParthSareen committed
143
	}
144
	compareLogits(t, "topK(-1)", want, tokens)
ParthSareen's avatar
ParthSareen committed
145
146
147

	input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
	want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
148
149
150
151
152
153
154
155
156
157
158
159
	tokens = toTokens(input)
	tokens = topK(tokens, 0)
	if len(tokens) != len(input) {
		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
	}
	compareLogits(t, "topK(-1)", want, tokens)

	input = []float32{-1e7, -2e7, -3e7, -4e7}
	tokens = toTokens(input)
	tokens = topK(tokens, 1)
	if len(tokens) < 1 {
		t.Error("topK should keep at least one token")
ParthSareen's avatar
ParthSareen committed
160
	}
161
162
163
}

func TestTopP(t *testing.T) {
164
	input := []float32{-3, -2, -1, 0, 1, 2, 4}
165
	tokens := toTokens(input)
166
167

	// First apply temperature and softmax to get probabilities
168
	softmax(tokens)
ParthSareen's avatar
ParthSareen committed
169
	tokens = topK(tokens, 20)
170

171
172
	// Test with very high p value
	got := topP(tokens, 1.0)
173

174
175
176
177
178
179
180
181
182
	// Should keep all tokens since p is 1
	if len(got) != len(input) {
		t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
	}

	// Test with normal p value
	got = topP(tokens, 0.95)

	if len(got) > 3 {
183
		t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
184
		t.Logf("got: %v", got)
185
186
187
	}

	// Test edge case - ensure at least one token remains
188
	input = []float32{-1e6, -1e6, -1e7}
189
	tokens = toTokens(input)
190
	tokens = topK(tokens, 20)
191
	softmax(tokens)
192
193
	got = topP(tokens, 0.0)
	if len(got) < 1 {
194
		t.Error("topP should keep at least one token")
195
	}
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213

	// Test with zero p value
	got = topP(tokens, 0.0)

	// Should keep only the highest probability token
	if len(got) != 1 {
		t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
		t.Logf("got: %v", got)
	}

	tokens = toTokens(input)
	tokens = topK(tokens, 20)
	softmax(tokens)
	got = topP(tokens, 1e-10)
	if len(got) == 0 {
		t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
		t.Logf("got: %v", got)
	}
214
215
216
}

func TestMinP(t *testing.T) {
217
	input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
218
	tokens := toTokens(input)
219
220

	// First apply temperature and softmax
221
222
	tokens = topK(tokens, 20)
	softmax(tokens)
223

224
225
226
227
228
229
230
231
232
233
234
	tokens = minP(tokens, 1.0)

	if len(tokens) != 1 {
		t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
	}

	// Test with normal p value
	tokens = toTokens(input) // Reset tokens
	tokens = topK(tokens, 20)
	softmax(tokens)
	tokens = minP(tokens, 0.2)
235
236

	// Should keep tokens with prob >= 0.2 * max_prob
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
	if len(tokens) > 3 {
		t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
		t.Logf("got: %v", tokens)
	}

	// Test with zero p value
	tokens = toTokens(input) // Reset tokens
	tokens = topK(tokens, 20)
	softmax(tokens)
	tokens = minP(tokens, 0.0)

	// Should keep only the highest probability token
	if len(tokens) != len(input) {
		t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
		t.Logf("got: %v", tokens)
	}

254
255
256
257
258
259
260
261
262
263
264
265
	// Test with single token
	tokens = toTokens(input[:1])
	tokens = topK(tokens, 20)
	softmax(tokens)
	tokens = minP(tokens, 0.1)

	// Should keep only the highest probability token
	if len(tokens) != 1 {
		t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
		t.Logf("got: %v", tokens)
	}

266
267
268
269
270
271
	input = []float32{1e-10, 1e-10, 1e-10}
	tokens = toTokens(input)
	softmax(tokens)
	tokens = minP(tokens, 1.0)
	if len(tokens) < 1 {
		t.Error("minP should keep at least one token even with extreme probabilities")
272
		got := minP(tokens, 1.0)
273

274
275
276
		if len(got) != 1 {
			t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
		}
277

278
279
		// Test with normal p value
		got = minP(tokens, 0.2)
280

281
282
283
284
		// Should keep tokens with prob >= 0.2 * max_prob
		if len(got) > 3 {
			t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
			t.Logf("got: %v", got)
285
		}
286

287
288
289
290
291
292
293
294
295
		// Test with zero p value
		got = minP(tokens, 0.0)

		// Should keep only the highest probability token
		if len(got) != len(tokens) {
			t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
			t.Logf("got: %v", got)
		}
	}
296
297
298
299
}

func BenchmarkTransforms(b *testing.B) {
	// Generate random logits
300
	tokens := make([]token, 1<<16)
301
	for i := range tokens {
302
		tokens[i] = token{
303
304
305
			id:    int32(i),
			value: rand.Float32(),
		}
306
	}
307

308
	tokensCopy := make([]token, len(tokens))
309
310
311
312
313
314
315
316
317

	b.Run("Temperature", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			temperature(tokensCopy, 0.5)
		}
	})

318
319
320
321
322
323
324
325
	b.Run("Softmax", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			softmax(tokensCopy)
		}
	})

326
327
328
329
	b.Run("TopK", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
330
			tokens = topK(tokensCopy, 10)
331
332
333
334
335
336
337
		}
	})

	b.Run("TopP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
338
			tokens = topP(tokensCopy, 0.9)
339
340
341
342
343
344
345
		}
	})

	b.Run("MinP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
346
			tokens = minP(tokensCopy, 0.2)
347
348
349
350
351
352
353
		}
	})

	b.Run("SortTokens", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
354
			tokens = topK(tokensCopy, 200000)
355
356
		}
	})
357
}