samplers_test.go 1.83 KB
Newer Older
1
2
3
package sample

import (
4
	"math"
5
6
7
8
9
	"math/rand/v2"
	"testing"
)

func TestWeighted(t *testing.T) {
10
	logits := []float32{-10, 3, -10, -10}
11
	sampler := NewSampler(0, 0, 0, 0, 0, nil)
12
	got, err := sampler.Sample(logits)
13
14
15
16
17
18
19
20
21
	if err != nil {
		t.Error(err)
		return
	}
	want := int32(1)
	if want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}

22
	logits = []float32{-100, -10, 0, 10}
23
	sampler = NewSampler(0, 0, 0, 0, 0, nil)
24
	got, err = sampler.Sample(logits)
25
26
27
28
	if err != nil {
		t.Error(err)
		return
	}
29
	want = int32(3) // Should pick highest probability with this r value
30
31
32
	if want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

	// Test very high p
	logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
	// Use extremely small topP to filter out all tokens
	sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
	got, err = sampler.Sample(logits)
	if err != nil {
		t.Error(err)
		return
	}
	// Should get the token with the highest logit
	want = int32(0)
	if want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}

	logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
	sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
	got, err = sampler.Sample(logits)
	if err == nil {
		t.Errorf("expected error, got %d", got)
		return
	}
56
57
58
59
}

func BenchmarkSample(b *testing.B) {
	samplers := map[string]Sampler{
60
61
		"Greedy":   NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
		"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
62
63
	}

64
	// Generate random logits for benchmarking
65
66
67
68
69
70
71
72
	logits := make([]float32, 1<<16)
	for i := range logits {
		logits[i] = rand.Float32()
	}

	for name, s := range samplers {
		b.Run(name, func(b *testing.B) {
			b.ResetTimer()
73
			for b.Loop() {
74
				if _, err := s.Sample(logits); err != nil {
75
					b.Fatalf("error sampling: %v", err)
76
77
78
79
80
				}
			}
		})
	}
}