samplers.go 2.94 KB
Newer Older
1
2
3
4
package sample

import (
	"errors"
5
6
	"math/rand/v2"
	"slices"
7
8
)

9
// Sampler chooses one token from a slice of logits and returns its index.
//
// Implementations are not safe for concurrent use: every goroutine should
// own a separate instance.
type Sampler interface {
	Sample(logits []float32) (int32, error)
}

14
15
16
17
18
19
// logit carries one candidate token through the sampling pipeline.
// Depending on the stage, value is a raw model logit or a probability;
// weighted sampling also overwrites it with a running cumulative sum.
type logit struct {
	// id is the token's identifier (its index in the model's logits).
	id int32

	// value is the raw logit or probability from the model.
	value float32
}

20
type weighted struct {
21
22
23
24
25
26
	rng         *rand.Rand
	tokens      []logit
	topK        int
	topP        float32
	minP        float32
	temperature float32
27
28
}

29
30
31
// Sample draws one token id from logits at random.
//
// The candidates are first ordered by logit (topK does this when s.topK > 0,
// otherwise sortLogits). They are then passed through temperature, softmax,
// topP and minP in that order; the filters may drop candidates. The surviving
// weights are turned into a running total in place. A uniform draw, scaled by
// that total, is then located with a binary search (inverse-CDF sampling).
// Because the draw is scaled by the total, the surviving weights need not sum
// to 1.
//
// Sample writes into s.tokens, which it reuses between calls, so a weighted
// sampler must not be used from several goroutines at once. It returns -1 and
// an error when no candidate survives filtering.
func (s *weighted) Sample(logits []float32) (int32, error) {
	// Reallocate the scratch buffer only when it is too short for this call.
	if len(s.tokens) < len(logits) {
		s.tokens = make([]logit, len(logits))
	}
	cands := s.tokens[:len(logits)]
	for i := range logits {
		cands[i] = logit{id: int32(i), value: logits[i]}
	}

	// Both branches leave cands ordered by logit.
	switch {
	case s.topK > 0:
		cands = topK(cands, s.topK)
	default:
		sortLogits(cands)
	}

	cands = temperature(cands, s.temperature)
	cands = softmax(cands)
	cands = topP(cands, s.topP)
	cands = minP(cands, s.minP)
	if len(cands) == 0 {
		return -1, errors.New("no valid logits found for weighted sampling")
	}

	// Use the seeded generator when one was configured, else the global one.
	draw := rand.Float32
	if s.rng != nil {
		draw = s.rng.Float32
	}
	u := draw()

	// Replace each weight with the cumulative sum up to and including it.
	var total float32
	for i := range cands {
		total += cands[i].value
		cands[i].value = total
	}
	target := u * total

	// Find the first candidate whose cumulative weight is not below target.
	// The comparator never reports equality, so the "found" flag is always
	// false and the returned index is that insertion point.
	pick, _ := slices.BinarySearchFunc(cands, target, func(c logit, t float32) int {
		if c.value < t {
			return -1
		}
		return 1
	})

	// Defensive: never index past the last candidate.
	pick = min(pick, len(cands)-1)
	return cands[pick].id, nil
}

89
90
91
// greedy is a deterministic Sampler that always chooses the largest logit.
type greedy struct{}

// Sample returns the index of the largest value in logits. When several
// entries share the maximum, the lowest index wins. An empty slice yields -1
// and an error.
func (s greedy) Sample(logits []float32) (int32, error) {
	if len(logits) == 0 {
		return -1, errors.New("no logits provided for greedy sampling")
	}

	best := 0
	for i, v := range logits {
		// Strict comparison keeps the earliest index on ties.
		if v > logits[best] {
			best = i
		}
	}
	return int32(best), nil
}

// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
110
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
111
	if temperature == 0 {
112
		return &greedy{}
113
114
	}

115
116
117
118
119
120
121
	var rng *rand.Rand
	if seed != -1 {
		// PCG requires two parameters: sequence and stream
		// Use original seed for sequence
		sequence := uint64(seed)
		// Use golden ratio hash to generate statistically independent seeds
		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
122
	}
123
	temperature = max(temperature, 1)
124

125
126
	if topP < 0.0 {
		topP = 0.0
127
	}
128
129
	if topP >= 1.0 {
		topP = 1.0
130
131
	}

132
133
134
135
136
	if minP < 0.0 {
		minP = 0.0
	}
	if minP >= 1.0 {
		minP = 1.0
137
138
	}

139
140
141
142
143
144
	return &weighted{
		rng:         rng,
		topK:        topK,
		topP:        topP,
		minP:        minP,
		temperature: temperature,
145
146
	}
}