// transforms_test.go

package sample

import (
	"math"
	"math/rand/v2"
	"testing"
)

9
// Helper to convert float64 slice to logit slice
10
11
func toTokens(values []float64) []token {
	tokens := make([]token, len(values))
12
	for i, v := range values {
13
		tokens[i] = token{
14
15
16
17
18
19
20
21
			id:    int32(i),
			value: float32(v),
		}
	}
	return tokens
}

// Helper to compare logit slices
22
func compareLogits(t *testing.T, name string, want []float64, got []token) {
23
24
25
26
27
28
29
30
31
	t.Helper()
	if len(want) != len(got) {
		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
		return
	}
	for i := range want {
		if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
			t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
		}
32
33
34
	}
}

35
36
func TestTemperatureAndSoftmax(t *testing.T) {
	input := []float64{1, 4, -2, 0}
37
	got := temperature(toTokens(input), 0.5)
38

39
40
41
42
43
44
45
	// Check probabilities sum to 1
	var sum float32
	for _, token := range got {
		sum += token.value
	}
	if math.Abs(float64(sum)-1.0) > 1e-6 {
		t.Errorf("probabilities don't sum to 1: got %f", sum)
46
47
	}

48
49
50
51
52
53
54
55
	got = temperature(toTokens(input), 1)
	// Check probabilities sum to 1
	sum = 0.0
	for _, token := range got {
		sum += token.value
	}
	if math.Abs(float64(sum)-1.0) > 1e-6 {
		t.Errorf("probabilities don't sum to 1: got %f", sum)
56
	}
57
}
58

59
60
func TestTopK(t *testing.T) {
	input := []float64{-3, -2, -1, 0, 1, 2, 4}
61

62
	// Test k=3
63
	got := topK(toTokens(input), 3)
64
65
	if len(got) != 3 {
		t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
66
	}
67
68
69
70
71
	// Should keep highest 3 values: 4, 2, 1
	want := []float64{4, 2, 1}
	compareLogits(t, "topK(3)", want, got)

	// Test k > len
72
	got = topK(toTokens(input), 10)
73
	compareLogits(t, "topK(10)", input, got)
74
75
76
}

func TestTopP(t *testing.T) {
77
	input := []float64{-3, -2, -1, 0, 1, 2, 4}
78
	tokens := toTokens(input)
79
80
81
82
83
84
85
86
87
88
89
90

	// First apply temperature and softmax to get probabilities
	tokens = temperature(tokens, 1)
	sortLogits(tokens)

	// Then apply topP
	got := topP(tokens, 0.95)

	// Should keep tokens until cumsum > 0.95
	if len(got) > 3 {
		t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
		t.Logf("got: %v", got)
91
92
93
94
	}
}

func TestMinP(t *testing.T) {
95
	input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
96
	tokens := toTokens(input)
97
98
99
100
101
102
103
104
105
106

	// First apply temperature and softmax
	tokens = temperature(tokens, 1)

	// Then apply minP
	got := minP(tokens, 0.2)

	// Should keep tokens with prob >= 0.2 * max_prob
	if len(got) > 3 {
		t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
107
108
109
	}
}

110
111
func TestSortLogits(t *testing.T) {
	input := []float64{3, 1, 4, 2, -1, 0, -2}
112
	tokens := toTokens(input)
113
114

	sortLogits(tokens)
115

116
117
118
119
120
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > tokens[i-1].value {
			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
				i, tokens[i].value, tokens[i-1].value)
		}
121
122
	}

123
124
125
126
127
128
	want := []float64{4, 3, 2, 1, 0, -1, -2}
	compareLogits(t, "sortLogits", want, tokens)
}

func BenchmarkTransforms(b *testing.B) {
	// Generate random logits
129
	tokens := make([]token, 1<<16)
130
	for i := range tokens {
131
		tokens[i] = token{
132
133
134
			id:    int32(i),
			value: rand.Float32(),
		}
135
	}
136

137
	tokensCopy := make([]token, len(tokens))
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

	b.Run("Temperature", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			temperature(tokensCopy, 0.5)
		}
	})

	b.Run("TopK", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			topK(tokensCopy, 10)
		}
	})

	b.Run("TopP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			topP(tokensCopy, 0.9)
		}
	})

	b.Run("MinP", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			minP(tokensCopy, 0.2)
		}
	})

	b.Run("SortTokens", func(b *testing.B) {
		b.ResetTimer()
		for b.Loop() {
			copy(tokensCopy, tokens)
			sortLogits(tokensCopy)
		}
	})
178
}