package sample

import (
	"math"
	"math/rand/v2"
	"testing"
)

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// Helper to convert float64 slice to logit slice
func toLogits(values []float64) []logit {
	tokens := make([]logit, len(values))
	for i, v := range values {
		tokens[i] = logit{
			id:    int32(i),
			value: float32(v),
		}
	}
	return tokens
}

// compareLogits reports through t every position where got's values differ
// from want. Values are compared as float64 with an absolute tolerance to
// absorb float32 rounding; ids are not compared. A length mismatch is
// reported once and ends the comparison. name prefixes every message.
func compareLogits(t *testing.T, name string, want []float64, got []logit) {
	t.Helper()
	const tolerance = 1e-6

	if len(want) != len(got) {
		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
		return
	}
	for i := range want {
		diff := math.Abs(float64(got[i].value) - want[i])
		// Written as !(diff <= tolerance) rather than diff > tolerance so that
		// a NaN difference (e.g. a transform producing NaN) is reported
		// instead of silently passing.
		if !(diff <= tolerance) {
			t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
		}
	}
}

35
36
37
38
39
40
41
42
func TestTemperature(t *testing.T) {
	input := []float64{2, -1, 4, -3, 1, -2, 0}
	want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp

	got := temperature(toLogits(input), 0.5)
	compareLogits(t, "Temperature", want, got)
}

43
func TestSoftmax(t *testing.T) {
44
45
	input := []float64{-3, -2, -1, 0, 1, 2, 4}
	got := softmax(toLogits(input))
46

47
48
49
50
51
52
53
	// Check probabilities sum to 1
	var sum float32
	for _, token := range got {
		sum += token.value
	}
	if math.Abs(float64(sum)-1.0) > 1e-6 {
		t.Errorf("probabilities don't sum to 1: got %f", sum)
54
55
	}

56
57
58
59
60
	// Check relative ordering is preserved
	for i := 1; i < len(got); i++ {
		if got[i].value < got[i-1].value {
			t.Errorf("probability ordering not preserved at index %d", i)
		}
61
	}
62
}
63

64
65
func TestTopK(t *testing.T) {
	input := []float64{-3, -2, -1, 0, 1, 2, 4}
66

67
68
69
70
	// Test k=3
	got := topK(toLogits(input), 3)
	if len(got) != 3 {
		t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
71
	}
72
73
74
75
76
77
78
	// Should keep highest 3 values: 4, 2, 1
	want := []float64{4, 2, 1}
	compareLogits(t, "topK(3)", want, got)

	// Test k > len
	got = topK(toLogits(input), 10)
	compareLogits(t, "topK(10)", input, got)
79
80
81
}

func TestTopP(t *testing.T) {
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
	input := []float64{-3, -2, -1, 0, 1, 2, 4}
	tokens := toLogits(input)

	// First apply temperature and softmax to get probabilities
	tokens = temperature(tokens, 1)
	tokens = softmax(tokens)
	sortLogits(tokens)

	// Then apply topP
	got := topP(tokens, 0.95)

	// Should keep tokens until cumsum > 0.95
	if len(got) > 3 {
		t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
		t.Logf("got: %v", got)
97
98
99
100
	}
}

func TestMinP(t *testing.T) {
101
102
103
104
105
106
107
108
109
110
111
112
113
	input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
	tokens := toLogits(input)

	// First apply temperature and softmax
	tokens = temperature(tokens, 1)
	tokens = softmax(tokens)

	// Then apply minP
	got := minP(tokens, 0.2)

	// Should keep tokens with prob >= 0.2 * max_prob
	if len(got) > 3 {
		t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
114
115
116
	}
}

117
118
119
120
121
func TestSortLogits(t *testing.T) {
	input := []float64{3, 1, 4, 2, -1, 0, -2}
	tokens := toLogits(input)

	sortLogits(tokens)
122

123
124
125
126
127
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > tokens[i-1].value {
			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
				i, tokens[i].value, tokens[i-1].value)
		}
128
129
	}

130
131
132
133
134
135
136
137
138
139
140
141
	want := []float64{4, 3, 2, 1, 0, -1, -2}
	compareLogits(t, "sortLogits", want, tokens)
}

// BenchmarkTransforms times each transform on 1<<16 logits whose values are
// drawn uniformly from [0, 1).
//
// Each iteration first restores tokensCopy from the pristine tokens slice,
// because a transform may modify its input. The copy is therefore included
// in every measurement. Calls stay directly inside the b.Loop body so the
// compiler cannot optimize them away.
func BenchmarkTransforms(b *testing.B) {
	// Generate random logits
	tokens := make([]logit, 1<<16)
	for i := range tokens {
		tokens[i] = logit{
			id:    int32(i),
			value: rand.Float32(),
		}
	}

	tokensCopy := make([]logit, len(tokens))

	// b.Loop resets the timer on its first call, so the setup above is
	// excluded without an explicit b.ResetTimer.
	b.Run("Temperature", func(b *testing.B) {
		for b.Loop() {
			copy(tokensCopy, tokens)
			temperature(tokensCopy, 0.5)
		}
	})

	b.Run("TopK", func(b *testing.B) {
		for b.Loop() {
			copy(tokensCopy, tokens)
			topK(tokensCopy, 10)
		}
	})

	// NOTE(review): unlike TestTopP/TestMinP, the TopP and MinP inputs here
	// are neither softmax-normalized nor sorted. Confirm this is the
	// intended workload.
	b.Run("TopP", func(b *testing.B) {
		for b.Loop() {
			copy(tokensCopy, tokens)
			topP(tokensCopy, 0.9)
		}
	})

	b.Run("MinP", func(b *testing.B) {
		for b.Loop() {
			copy(tokensCopy, tokens)
			minP(tokensCopy, 0.2)
		}
	})

	b.Run("SortTokens", func(b *testing.B) {
		for b.Loop() {
			copy(tokensCopy, tokens)
			sortLogits(tokensCopy)
		}
	})
}