samplers_test.go 1.19 KB
Newer Older
1
2
3
4
5
6
7
8
package sample

import (
	"math/rand/v2"
	"testing"
)

func TestWeighted(t *testing.T) {
9
	logits := []float32{-10, 3, -10, -10}
10
	sampler := NewSampler(0, 0, 0, 0, 0, nil)
11
	got, err := sampler.Sample(logits)
12
13
14
15
16
17
18
19
20
	if err != nil {
		t.Error(err)
		return
	}
	want := int32(1)
	if want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}

21
	logits = []float32{-100, -10, 0, 10}
22
	sampler = NewSampler(0, 0, 0, 0, 0, nil)
23
	got, err = sampler.Sample(logits)
24
25
26
27
	if err != nil {
		t.Error(err)
		return
	}
28
	want = int32(3) // Should pick highest probability with this r value
29
30
31
32
33
34
35
	if want != got {
		t.Errorf("index mismatch: want %d, got %d", want, got)
	}
}

func BenchmarkSample(b *testing.B) {
	samplers := map[string]Sampler{
36
37
		"Greedy":   NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
		"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
38
39
	}

40
	// Generate random logits for benchmarking
41
42
43
44
45
46
47
48
	logits := make([]float32, 1<<16)
	for i := range logits {
		logits[i] = rand.Float32()
	}

	for name, s := range samplers {
		b.Run(name, func(b *testing.B) {
			b.ResetTimer()
49
			for b.Loop() {
50
				if _, err := s.Sample(logits); err != nil {
51
					b.Fatalf("error sampling: %v", err)
52
53
54
55
56
				}
			}
		})
	}
}