// Tests and benchmarks for the sampling transforms in this package.

package sample

import (
	"math"
	"math/rand/v2"
	"testing"
)

// Helper to convert float32 slice to logit slice
func toTokens(values []float32) []token {
11
	tokens := make([]token, len(values))
12
	for i, v := range values {
13
		tokens[i] = token{
14
			id:    int32(i),
15
			value: v,
16
17
18
19
20
21
		}
	}
	return tokens
}

// Helper to compare logit slices
22
func compareLogits(t *testing.T, name string, want []float32, got []token) {
23
24
25
26
27
28
	t.Helper()
	if len(want) != len(got) {
		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
		return
	}
	for i := range want {
29
		if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
30
31
			t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
		}
32
33
34
	}
}

func TestTemperatureAndSoftmax(t *testing.T) {
36
	input := []float32{1, 4, -2, 0}
37
	got := temperature(toTokens(input), 0.5)
38

39
40
41
42
43
	// Check probabilities sum to 1
	var sum float32
	for _, token := range got {
		sum += token.value
	}
44
	if math.Abs(float64(sum-1.0)) > 1e-6 {
45
		t.Errorf("probabilities don't sum to 1: got %f", sum)
46
47
	}

48
49
50
51
52
53
	got = temperature(toTokens(input), 1)
	// Check probabilities sum to 1
	sum = 0.0
	for _, token := range got {
		sum += token.value
	}
54
	if math.Abs(float64(sum-1.0)) > 1e-6 {
55
		t.Errorf("probabilities don't sum to 1: got %f", sum)
56
	}
57
}
func TestTopK(t *testing.T) {
60
	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
61

62
	// Test k=3
63
64
65
	got := topK(toTokens(input), 5)
	if len(got) != 5 {
		t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
66
	}
67
68
	// Should keep highest 3 values in descending order
	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
69
70
	compareLogits(t, "topK(3)", want, got)

71
72
73
74
	got = topK(toTokens(input), 20)
	if len(got) != len(input) {
		t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
	}
75
76
77
}

func TestTopP(t *testing.T) {
78
	input := []float32{-3, -2, -1, 0, 1, 2, 4}
79
	tokens := toTokens(input)
80
81
82
83
84
85
86
87
88
89
90
91

	// First apply temperature and softmax to get probabilities
	tokens = temperature(tokens, 1)
	sortLogits(tokens)

	// Then apply topP
	got := topP(tokens, 0.95)

	// Should keep tokens until cumsum > 0.95
	if len(got) > 3 {
		t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
		t.Logf("got: %v", got)
92
93
94
95
	}
}

func TestMinP(t *testing.T) {
96
	input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
97
	tokens := toTokens(input)
98
99
100
101
102
103
104
105
106
107

	// First apply temperature and softmax
	tokens = temperature(tokens, 1)

	// Then apply minP
	got := minP(tokens, 0.2)

	// Should keep tokens with prob >= 0.2 * max_prob
	if len(got) > 3 {
		t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
108
109
110
	}
}

func TestSortLogits(t *testing.T) {
112
	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
113
	tokens := toTokens(input)
114
115

	sortLogits(tokens)
116

117
118
119
120
121
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > tokens[i-1].value {
			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
				i, tokens[i].value, tokens[i-1].value)
		}
122
123
	}

124
	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
125
126
127
128
129
	compareLogits(t, "sortLogits", want, tokens)
}

func BenchmarkTransforms(b *testing.B) {
	// Generate random logits
130
	tokens := make([]token, 1<<16)
131
	for i := range tokens {
132
		tokens[i] = token{
133
134
135
			id:    int32(i),
			value: rand.Float32(),
		}
136
	}
137

138
	tokensCopy := make([]token, len(tokens))
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

	b.Run("Temperature", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			temperature(tokensCopy, 0.5)
		}
	})

	b.Run("TopK", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			topK(tokensCopy, 10)
		}
	})

	b.Run("TopP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			topP(tokensCopy, 0.9)
		}
	})

	b.Run("MinP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			minP(tokensCopy, 0.2)
		}
	})

	b.Run("SortTokens", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			sortLogits(tokensCopy)
		}
	})
179
}