transforms_test.go 2.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
package sample

import (
	"math"
	"math/rand/v2"
	"testing"

	"github.com/google/go-cmp/cmp"
)

// TestTemperature checks Temperature(0.5).Apply on a small, unsorted slice of
// logits. Each expected value equals the input divided by 0.5 and then shifted
// down by the largest such quotient (8), so the maximum input maps to 0.
func TestTemperature(t *testing.T) {
	logits := []float64{2, -1, 4, -3, 1, -2, 0}
	expected := []float64{-4, -10, 0, -14, -6, -12, -8}

	actual := Temperature(0.5).Apply(logits)
	if d := cmp.Diff(expected, actual); d != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", d)
	}
}

func TestSoftmax(t *testing.T) {
	got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})

	want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("probs mismatch (-want +got):\n%s", diff)
	}
}

func TestTopK(t *testing.T) {
	got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", diff)
	}

	got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})

	want = []float64{-3, -2, -1, 0, 1, 2, 4}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", diff)
	}
}

func TestTopP(t *testing.T) {
	got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", diff)
	}
}

// TestMinP checks MinP(0.2) on unsorted logits, with the two largest at the
// end. Relative to the maximum logit (4), a logit l has probability ratio
// exp(l-4). For 3 that is about 0.37, and for 2 about 0.14. The expected output
// keeps only the 4 and 3 positions, which is consistent with MinP presumably
// dropping anything below 0.2 times the top probability (confirm against the
// implementation).
func TestMinP(t *testing.T) {
	in := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
	ninf := math.Inf(-1)
	expected := []float64{ninf, ninf, ninf, ninf, ninf, ninf, 4, 3}

	out := MinP(0.2).Apply(in)
	if d := cmp.Diff(expected, out); d != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", d)
	}
}

// BenchmarkTransform measures Apply for each Transform on a 65536-element
// logit slice.
//
// Sub-benchmarks run in a fixed order (a slice, not a map), and the input
// comes from a fixed-seed generator, so runs are reproducible and comparable
// with benchstat. Each sub-benchmark invocation works on its own copy of the
// input, so no transform ever sees data already processed by a different
// transform.
func BenchmarkTransform(b *testing.B) {
	transforms := []struct {
		name      string
		transform Transform
	}{
		{"Temperature", Temperature(0.5)},
		{"TopK", TopK(10)},
		{"TopP", TopP(0.9)},
		{"MinP", MinP(0.2)},
	}

	// Fixed seed: the global math/rand/v2 generator is randomly seeded,
	// which would make the benchmark input differ on every run.
	r := rand.New(rand.NewPCG(1, 2))
	logits := make([]float64, 1<<16)
	for i := range logits {
		logits[i] = r.Float64()
	}

	for _, tc := range transforms {
		b.Run(tc.name, func(b *testing.B) {
			// NOTE(review): whether Apply mutates its argument is not visible
			// here. If it does, iterations after the first within this run see
			// already-transformed data. The copy below at least isolates
			// sub-benchmarks (and repeated b.Run invocations) from each other.
			in := make([]float64, len(logits))
			copy(in, logits)

			b.ReportAllocs()
			b.ResetTimer() // exclude the copy above from the measurement
			for range b.N {
				tc.transform.Apply(in)
			}
		})
	}
}