// Tests and benchmarks for the token transforms: temperature, softmax,
// topK, topP, minP and sortLogits.

package sample

import (
	"math"
	"math/rand/v2"
	"testing"
)

9
// Helper to convert float64 slice to logit slice
10
11
func toTokens(values []float64) []token {
	tokens := make([]token, len(values))
12
	for i, v := range values {
13
		tokens[i] = token{
14
15
16
17
18
19
20
21
			id:    int32(i),
			value: float32(v),
		}
	}
	return tokens
}

// Helper to compare logit slices
22
func compareLogits(t *testing.T, name string, want []float64, got []token) {
23
24
25
26
27
28
29
30
31
	t.Helper()
	if len(want) != len(got) {
		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
		return
	}
	for i := range want {
		if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
			t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
		}
32
33
34
	}
}

35
36
37
38
func TestTemperature(t *testing.T) {
	input := []float64{2, -1, 4, -3, 1, -2, 0}
	want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp

39
	got := temperature(toTokens(input), 0.5)
40
41
42
	compareLogits(t, "Temperature", want, got)
}

43
func TestSoftmax(t *testing.T) {
44
	input := []float64{-3, -2, -1, 0, 1, 2, 4}
45
	got := softmax(toTokens(input))
46

47
48
49
50
51
52
53
	// Check probabilities sum to 1
	var sum float32
	for _, token := range got {
		sum += token.value
	}
	if math.Abs(float64(sum)-1.0) > 1e-6 {
		t.Errorf("probabilities don't sum to 1: got %f", sum)
54
55
	}

56
57
58
59
60
	// Check relative ordering is preserved
	for i := 1; i < len(got); i++ {
		if got[i].value < got[i-1].value {
			t.Errorf("probability ordering not preserved at index %d", i)
		}
61
	}
62
}
63

64
65
func TestTopK(t *testing.T) {
	input := []float64{-3, -2, -1, 0, 1, 2, 4}
66

67
	// Test k=3
68
	got := topK(toTokens(input), 3)
69
70
	if len(got) != 3 {
		t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
71
	}
72
73
74
75
76
	// Should keep highest 3 values: 4, 2, 1
	want := []float64{4, 2, 1}
	compareLogits(t, "topK(3)", want, got)

	// Test k > len
77
	got = topK(toTokens(input), 10)
78
	compareLogits(t, "topK(10)", input, got)
79
80
81
}

func TestTopP(t *testing.T) {
82
	input := []float64{-3, -2, -1, 0, 1, 2, 4}
83
	tokens := toTokens(input)
84
85
86
87
88
89
90
91
92
93
94
95
96

	// First apply temperature and softmax to get probabilities
	tokens = temperature(tokens, 1)
	tokens = softmax(tokens)
	sortLogits(tokens)

	// Then apply topP
	got := topP(tokens, 0.95)

	// Should keep tokens until cumsum > 0.95
	if len(got) > 3 {
		t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
		t.Logf("got: %v", got)
97
98
99
100
	}
}

func TestMinP(t *testing.T) {
101
	input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
102
	tokens := toTokens(input)
103
104
105
106
107
108
109
110
111
112
113

	// First apply temperature and softmax
	tokens = temperature(tokens, 1)
	tokens = softmax(tokens)

	// Then apply minP
	got := minP(tokens, 0.2)

	// Should keep tokens with prob >= 0.2 * max_prob
	if len(got) > 3 {
		t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
114
115
116
	}
}

117
118
func TestSortLogits(t *testing.T) {
	input := []float64{3, 1, 4, 2, -1, 0, -2}
119
	tokens := toTokens(input)
120
121

	sortLogits(tokens)
122

123
124
125
126
127
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > tokens[i-1].value {
			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
				i, tokens[i].value, tokens[i-1].value)
		}
128
129
	}

130
131
132
133
134
135
	want := []float64{4, 3, 2, 1, 0, -1, -2}
	compareLogits(t, "sortLogits", want, tokens)
}

func BenchmarkTransforms(b *testing.B) {
	// Generate random logits
136
	tokens := make([]token, 1<<16)
137
	for i := range tokens {
138
		tokens[i] = token{
139
140
141
			id:    int32(i),
			value: rand.Float32(),
		}
142
	}
143

144
	tokensCopy := make([]token, len(tokens))
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184

	b.Run("Temperature", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			temperature(tokensCopy, 0.5)
		}
	})

	b.Run("TopK", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			topK(tokensCopy, 10)
		}
	})

	b.Run("TopP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			topP(tokensCopy, 0.9)
		}
	})

	b.Run("MinP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			minP(tokensCopy, 0.2)
		}
	})

	b.Run("SortTokens", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			sortLogits(tokensCopy)
		}
	})
185
}