transforms_test.go 6.72 KB
Newer Older
1
2
3
package sample

import (
4
5
	"encoding/binary"
	"errors"
6
7
	"math"
	"math/rand/v2"
8
9
10
	"os"
	"path/filepath"
	"runtime"
11
12
13
	"testing"
)

14
15
// Helper to convert float32 slice to logit slice
func toTokens(values []float32) []token {
16
	tokens := make([]token, len(values))
17
	for i, v := range values {
18
		tokens[i] = token{
19
			id:    int32(i),
20
			value: v,
21
22
23
24
25
26
		}
	}
	return tokens
}

// Helper to compare logit slices
27
func compareLogits(t *testing.T, name string, want []float32, got []token) {
28
29
30
31
32
33
	t.Helper()
	if len(want) != len(got) {
		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
		return
	}
	for i := range want {
34
		if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
35
36
			t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
		}
37
38
39
	}
}

40
func TestTemperatureAndSoftmax(t *testing.T) {
41
	input := []float32{1, 4, -2, 0}
42
	got := temperature(toTokens(input), 0.5)
43

44
45
46
47
48
	// Check probabilities sum to 1
	var sum float32
	for _, token := range got {
		sum += token.value
	}
49
	if math.Abs(float64(sum-1.0)) > 1e-6 {
50
		t.Errorf("probabilities don't sum to 1: got %f", sum)
51
52
	}

53
54
55
56
57
58
	got = temperature(toTokens(input), 1)
	// Check probabilities sum to 1
	sum = 0.0
	for _, token := range got {
		sum += token.value
	}
59
	if math.Abs(float64(sum-1.0)) > 1e-6 {
60
		t.Errorf("probabilities don't sum to 1: got %f", sum)
61
	}
62
}
63

64
func TestTopK(t *testing.T) {
65
	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
66

67
	// Test k=3
68
69
70
	got := topK(toTokens(input), 5)
	if len(got) != 5 {
		t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
71
	}
72
73
	// Should keep highest 3 values in descending order
	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
74
75
	compareLogits(t, "topK(3)", want, got)

76
77
78
79
	got = topK(toTokens(input), 20)
	if len(got) != len(input) {
		t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
	}
80
81
82
}

func TestTopP(t *testing.T) {
83
	input := []float32{-3, -2, -1, 0, 1, 2, 4}
84
	tokens := toTokens(input)
85
86
87
88
89
90
91
92
93
94
95
96

	// First apply temperature and softmax to get probabilities
	tokens = temperature(tokens, 1)
	sortLogits(tokens)

	// Then apply topP
	got := topP(tokens, 0.95)

	// Should keep tokens until cumsum > 0.95
	if len(got) > 3 {
		t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
		t.Logf("got: %v", got)
97
98
99
100
	}
}

func TestMinP(t *testing.T) {
101
	input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
102
	tokens := toTokens(input)
103
104
105
106
107
108
109
110
111
112

	// First apply temperature and softmax
	tokens = temperature(tokens, 1)

	// Then apply minP
	got := minP(tokens, 0.2)

	// Should keep tokens with prob >= 0.2 * max_prob
	if len(got) > 3 {
		t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
113
114
115
	}
}

116
func TestSortLogits(t *testing.T) {
117
	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
118
	tokens := toTokens(input)
119
120

	sortLogits(tokens)
121

122
123
124
125
126
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > tokens[i-1].value {
			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
				i, tokens[i].value, tokens[i-1].value)
		}
127
128
	}

129
	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
130
131
132
	compareLogits(t, "sortLogits", want, tokens)
}

133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
// TestSortLogitsWithRealData exercises sortLogits on logits loaded from
// testdata/logits.bin (a raw array of 32-bit floats). The test is skipped when
// the data cannot be loaded.
//
// It verifies two properties of the sorted slice for a prefix of n tokens,
// where n is sqrt(len)+1 clamped to [100, 1000] and to the number of tokens:
//   - the first n tokens are in descending value order;
//   - no token after position n has a value greater than the nth token.
//
// NOTE(review): the original test states that only the top n are guaranteed to
// be sorted; confirm against sortLogits.
func TestSortLogitsWithRealData(t *testing.T) {
	logits, err := loadTestLogits(t)
	if err != nil {
		// Skipf stops the test itself, so no explicit return is needed.
		t.Skipf("Skipping real logit test: %v", err)
	}

	tokens := toTokens(logits)
	sortLogits(tokens)

	// Calculate n for verification. The final clamp to len(tokens) keeps the
	// slicing below in range when the data file holds fewer than 100 values.
	n := int(math.Sqrt(float64(len(tokens)))) + 1
	n = min(max(n, 100), 1000, len(tokens))

	t.Logf("Testing with %d tokens, partial sorting top %d", len(tokens), n)

	// Only verify that the top n elements are sorted.
	topN := tokens[:n]
	for i := 1; i < len(topN); i++ {
		if topN[i].value > topN[i-1].value {
			t.Fatalf("top %d tokens not properly sorted at index %d: %.15f > %.15f",
				n, i, topN[i].value, topN[i-1].value)
		}
	}

	// Verify we didn't lose any high value tokens: everything after position n
	// must be <= the nth token.
	nthValue := tokens[n-1].value
	for i := n; i < len(tokens); i++ {
		if tokens[i].value > nthValue {
			t.Fatalf("found higher value token after position %d: tokens[%d].value = %.15f > %.15f",
				n, i, tokens[i].value, nthValue)
		}
	}
}

// loadTestLogits loads logit test data from testdata/logits.bin, located next
// to this source file. The file is decoded as a packed array of little-endian
// 32-bit floats.
//
// It returns an error when the source file path cannot be determined, the file
// cannot be read, its size is not a multiple of 4 bytes, or it is empty.
func loadTestLogits(t *testing.T) ([]float32, error) {
	t.Helper()

	_, currFile, _, ok := runtime.Caller(0)
	if !ok {
		return nil, errors.New("could not determine test file path")
	}
	testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")

	// Read the whole file in one call rather than issuing one small read per
	// float.
	raw, err := os.ReadFile(testDataPath)
	if err != nil {
		return nil, err
	}

	const floatSize = 4 // each float32 is 4 bytes
	if len(raw)%floatSize != 0 {
		return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
	}
	if len(raw) == 0 {
		return nil, errors.New("logits.bin is empty")
	}

	logits := make([]float32, len(raw)/floatSize)
	for i := range logits {
		bits := binary.LittleEndian.Uint32(raw[i*floatSize:])
		logits[i] = math.Float32frombits(bits)
	}

	return logits, nil
}

225
226
func BenchmarkTransforms(b *testing.B) {
	// Generate random logits
227
	tokens := make([]token, 1<<16)
228
	for i := range tokens {
229
		tokens[i] = token{
230
231
232
			id:    int32(i),
			value: rand.Float32(),
		}
233
	}
234

235
	tokensCopy := make([]token, len(tokens))
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275

	b.Run("Temperature", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			temperature(tokensCopy, 0.5)
		}
	})

	b.Run("TopK", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			topK(tokensCopy, 10)
		}
	})

	b.Run("TopP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			topP(tokensCopy, 0.9)
		}
	})

	b.Run("MinP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			minP(tokensCopy, 0.2)
		}
	})

	b.Run("SortTokens", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			sortLogits(tokensCopy)
		}
	})
276
}