package sample

import (
	"errors"
	"math"
	"math/rand/v2"
	"slices"
	"sync"

	"github.com/ollama/ollama/llama"
)

// token represents information about a single token during sampling
type token struct {
	id    int32   // The token's unique identifier
	value float32 // The raw logit or probability from the model
}

// Sampler selects a token id from model logits according to its parameters.
type Sampler struct {
	rng         *rand.Rand
	topK        int
	topP        float32
	minP        float32
	temperature float32
	grammar     *Grammar
}

// Sample picks one token id from logits, constrained by the grammar when one
// is configured.
func (s *Sampler) Sample(logits []float32) (int32, error) {
	tokens := make([]token, len(logits))
	for i := range logits {
		tokens[i].id = int32(i)
		tokens[i].value = logits[i]
	}

	t, err := s.sample(tokens)
	if err != nil {
		return -1, err
	}

	if s.grammar != nil {
		// optimization: first check whether the sampled token is accepted by
		// the grammar; only if it is rejected do we apply the grammar to all
		// logits (slower) and sample again
		top := []token{t}
		s.grammar.Apply(top)
		if !math.IsInf(float64(top[0].value), -1) {
			s.grammar.Accept(top[0].id)
			return top[0].id, nil
		}

		// .sample modifies the tokens in place, so reset them from the raw
		// logits before applying the grammar and sampling again
		for i := range logits {
			tokens[i].id = int32(i)
			tokens[i].value = logits[i]
		}
		s.grammar.Apply(tokens)
		t, err = s.sample(tokens)
		if err != nil {
			return -1, err
		}
		s.grammar.Accept(t.id)
	}

	return t.id, nil
}
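
// decodeLoop is a usage sketch, not part of the original file: it shows how
// Sample is typically driven in a generation loop. forward and eosID are
// hypothetical stand-ins for the caller's model integration.
func decodeLoop(s *Sampler, forward func(prev int32) []float32, eosID int32) ([]int32, error) {
	var out []int32
	prev := int32(-1) // hypothetical sentinel for "no previous token"
	for {
		id, err := s.Sample(forward(prev))
		if err != nil {
			return out, err
		}
		if id == eosID {
			return out, nil
		}
		out = append(out, id)
		prev = id
	}
}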

// greedy returns the token with the highest value (the raw logit at call time)
func greedy(tokens []token) token {
	max := tokens[0]
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > max.value {
			max = tokens[i]
		}
	}

	return max
}

// sample draws a token from tokens according to the sampler parameters,
// falling back to greedy selection when temperature is 0. It has the side
// effect of modifying the tokens slice in place.
func (s *Sampler) sample(tokens []token) (token, error) {
	if s.temperature == 0 {
		return greedy(tokens), nil
	}

	// topK also sorts the tokens in descending order of logits
	tokens = topK(tokens, s.topK)

	// token logit values are updated to probabilities
	tokens = temperature(tokens, s.temperature)

	tokens = topP(tokens, s.topP)
	tokens = minP(tokens, s.minP)

	// TODO: this should fall back to greedy sampling, or the topK/topP/minP
	// values should guarantee that there are always tokens to sample from
	if len(tokens) == 0 {
		return token{}, errors.New("no tokens to sample from")
	}

	// draw r in [0, 1), using the seeded RNG when one was configured
	var r float32
	if s.rng != nil {
		r = s.rng.Float32()
	} else {
		r = rand.Float32()
	}

	// Calculate cumulative sum of probabilities
	var sum float32
	for i := range tokens {
		sum += tokens[i].value
		tokens[i].value = sum
	}
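	// scale r by the total probability mass, since after filtering the
	// remaining values may no longer sum to 1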
	r *= tokens[len(tokens)-1].value

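	// inverse-CDF selection: find the first token whose cumulative value
	// reaches r; the comparator never returns 0, so the returned index is the
	// insertion point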
	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
		if token.value < target {
			return -1
		}
		return 1
	})

	return tokens[idx], nil
}

// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
// NewSampler clamps the parameters to valid ranges and returns a configured
// Sampler. A seed of -1 selects the global, non-deterministic RNG.
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
	var rng *rand.Rand
	if seed != -1 {
		// PCG requires two parameters: sequence and stream
		// Use original seed for sequence
		sequence := uint64(seed)
		// XOR with the 32-bit golden-ratio constant so the stream seed is
		// decorrelated from the sequence seed
		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
	}
	if temperature < 0.0 {
		temperature = 0.0
	}

	if topP < 0.0 {
		topP = 0.0
	}
	if topP >= 1.0 {
		topP = 1.0
	}

	if minP < 0.0 {
		minP = 0.0
	}
	if minP >= 1.0 {
		minP = 1.0
	}

	return Sampler{
		rng:         rng,
		topK:        topK,
		topP:        topP,
		minP:        minP,
		temperature: temperature,
		grammar:     grammar,
	}
}
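
// sampleOnce is a usage sketch, not part of the original file: it draws a
// single token id from a slice of logits. The parameter values are
// illustrative assumptions, not recommended defaults.
func sampleOnce(logits []float32) (int32, error) {
	// seed -1 selects the global RNG; a nil grammar leaves sampling
	// unconstrained
	s := NewSampler(0.7, 40, 0.9, 0.05, -1, nil)
	return s.Sample(logits)
}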

// Grammar constrains sampling to token sequences permitted by a formal
// grammar, backed by a llama.cpp grammar sampler.
type Grammar struct {
	vocab   *Vocab
	grammar string
	sampler *llama.Sampler
}

// NewGrammar loads the vocabulary and compiles grammar into a grammar
// sampler.
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
	v, err := vocab.Load()
	if err != nil {
		return nil, err
	}

	return &Grammar{
		vocab:   vocab,
		grammar: grammar,
		sampler: llama.NewGrammarSampler(v, grammar),
	}, nil
}

// Apply runs the grammar over tokens, setting the value of any token the
// grammar currently rejects to negative infinity.
func (g *Grammar) Apply(tokens []token) {
	tds := make([]llama.TokenData, len(tokens))
	for i, token := range tokens {
		tds[i].Id = token.id
		tds[i].Logit = token.value
	}

	g.sampler.Apply(tds)

	for i := range tokens {
		tokens[i].value = tds[i].Logit
	}
}

// Accept advances the grammar state to reflect that token was emitted.
func (g *Grammar) Accept(token int32) {
	g.sampler.Accept(token)
}

// Vocab lazily loads a llama.cpp vocabulary from a file, caching the result.
type Vocab struct {
	once  sync.Once
	vocab *llama.Vocab
	err   error
	path  string
}

func NewVocab(path string) *Vocab {
	return &Vocab{path: path}
}

// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
	v.once.Do(func() {
		vocab, err := llama.LoadVocabFromFile(v.path)
		if err != nil {
			v.err = err
			return
		}
		v.vocab = vocab
	})
	return v.vocab, v.err
}
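
// newConstrainedSampler is a usage sketch, not part of the original file: it
// wires a lazily loaded vocabulary and a grammar into a sampler. The file
// path and the grammar text are hypothetical.
func newConstrainedSampler() (Sampler, error) {
	vocab := NewVocab("/path/to/model.gguf")
	grammar, err := NewGrammar(vocab, `root ::= "yes" | "no"`)
	if err != nil {
		return Sampler{}, err
	}
	// temperature 0 selects greedily among the tokens the grammar allows
	return NewSampler(0, 0, 0, 0, -1, grammar), nil
}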