package sample

import (
	"errors"
	"math"
	"math/rand/v2"
	"slices"
	"sync"

	"github.com/ollama/ollama/llama"
)

// token represents information about a single token during sampling
type token struct {
	id    int32   // The token's unique identifier
	value float32 // The raw logit or probability from the model
}

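// Sampler selects a token id from model logits according to its configured
// parameters, optionally constraining the choice with a grammar.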
type Sampler struct {
	rng         *rand.Rand // nil falls back to the global math/rand/v2 source
	topK        int
	topP        float32
	minP        float32
	temperature float32
	grammar     *Grammar // nil disables grammar-constrained sampling
}

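// Sample selects a single token id from logits according to the sampler's
// parameters, re-sampling under the grammar constraint when one is set.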
func (s *Sampler) Sample(logits []float32) (int32, error) {
	tokens := make([]token, len(logits))
	for i := range logits {
		tokens[i].id = int32(i)
		tokens[i].value = logits[i]
	}

	t, err := s.sample(tokens)
	if err != nil {
		return -1, err
	}

	if s.grammar != nil {
		// optimization: first check if the sampled token is accepted by the
		// grammar; if it is rejected, apply the grammar to all logits (slower)
		top := []token{t}
		s.grammar.Apply(top)
		if !math.IsInf(float64(top[0].value), -1) {
			s.grammar.Accept(top[0].id)
			return top[0].id, nil
		}

		// .sample mutates the tokens slice, so rebuild it from the original
		// logits before applying the grammar and sampling again
		for i := range logits {
			tokens[i].id = int32(i)
			tokens[i].value = logits[i]
		}
		s.grammar.Apply(tokens)
		t, err = s.sample(tokens)
		if err != nil {
			return -1, err
		}
		s.grammar.Accept(t.id)
	}

	return t.id, nil
}

// greedy returns the highest probability token from the tokens
func greedy(tokens []token) token {
	max := tokens[0]
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > max.value {
			max = tokens[i]
		}
	}

	return max
}

// sample returns one token chosen from tokens according to the sampler
// parameters: greedy when temperature is 0, otherwise weighted sampling
// over the filtered distribution. It mutates the tokens slice.
func (s *Sampler) sample(tokens []token) (token, error) {
	if s.temperature == 0 {
		return greedy(tokens), nil
	}

	if s.topK > 0 {
		tokens = topK(tokens, s.topK)
	} else {
		sortLogits(tokens)
	}

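	// Transform pipeline: temperature scaling first, then softmax to turn
	// the remaining logits into probabilities, then top-p (nucleus) and
	// min-p filtering.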
	tokens = temperature(tokens, s.temperature)
	tokens = softmax(tokens)
	tokens = topP(tokens, s.topP)
	tokens = minP(tokens, s.minP)

	// TODO: fall back to greedy sampling here, or constrain the topP/topK
	// values so that there is always at least one token to sample from
	if len(tokens) == 0 {
		return token{}, errors.New("no tokens to sample from")
	}

	var r float32
	if s.rng != nil {
		r = s.rng.Float32()
	} else {
		r = rand.Float32()
	}

	// Inverse-CDF sampling: build a running cumulative sum of the
	// probabilities, scale the uniform draw r into [0, total), and
	// binary-search for the first token whose cumulative value reaches r.
	var sum float32
	for i := range tokens {
		sum += tokens[i].value
		tokens[i].value = sum
	}
	r *= tokens[len(tokens)-1].value

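	// E.g. with probabilities {0.5, 0.3, 0.2} the cumulative values become
	// {0.5, 0.8, 1.0}, and a scaled draw of r = 0.65 selects index 1
	// (illustrative values, not produced by this code).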
	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
		if token.value < target {
			return -1
		}
		return 1
	})

	return tokens[idx], nil
}

// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
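//
// NewSampler builds a Sampler from the given parameters, clamping
// temperature, topP, and minP into their valid ranges. A minimal usage
// sketch (the parameter values are illustrative; pass seed = -1 for a
// non-seeded source and a nil grammar to disable constrained decoding):
//
//	s := NewSampler(0.8, 40, 0.9, 0.05, 42, nil)
//	id, err := s.Sample(logits) // logits from the model's forward pass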
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
	var rng *rand.Rand
	if seed != -1 {
		// PCG requires two parameters: sequence and stream
		// Use original seed for sequence
		sequence := uint64(seed)
		// Use golden ratio hash to generate statistically independent seeds
		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
	}
	if temperature < 0.0 {
		temperature = 0.0
	}

	if topP < 0.0 {
		topP = 0.0
	}
	if topP >= 1.0 {
		topP = 1.0
	}

	if minP < 0.0 {
		minP = 0.0
	}
	if minP >= 1.0 {
		minP = 1.0
	}

	return Sampler{
		rng:         rng,
		topK:        topK,
		topP:        topP,
		minP:        minP,
		temperature: temperature,
		grammar:     grammar,
	}
}

type Grammar struct {
	vocab   *Vocab
	grammar string
	sampler *llama.Sampler
}

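// NewGrammar compiles grammar against the vocabulary at vocab's path; the
// vocabulary itself is loaded lazily on first use.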
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
	v, err := vocab.Load()
	if err != nil {
		return nil, err
	}

	return &Grammar{
		vocab:   vocab,
		grammar: grammar,
		sampler: llama.NewGrammarSampler(v, grammar),
	}, nil
}

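// Apply runs the grammar over tokens, rewriting the value of any token the
// grammar rejects to negative infinity via the underlying llama sampler.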
func (g *Grammar) Apply(tokens []token) {
	tds := make([]llama.TokenData, len(tokens))
	for i, token := range tokens {
		tds[i].Id = token.id
		tds[i].Logit = token.value
	}

	g.sampler.Apply(tds)

	for i := range tokens {
		tokens[i].value = tds[i].Logit
	}
}

func (g *Grammar) Accept(token int32) {
	g.sampler.Accept(token)
}

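// Vocab lazily loads a llama.cpp vocabulary from a model file, caching both
// the result and any load error for subsequent calls.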
type Vocab struct {
	once  sync.Once
	vocab *llama.Vocab
	err   error
	path  string
}

func NewVocab(path string) *Vocab {
	return &Vocab{path: path}
}

// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
	v.once.Do(func() {
		vocab, err := llama.LoadVocabFromFile(v.path)
		if err != nil {
			v.err = err
			return
		}
		v.vocab = vocab
	})
	return v.vocab, v.err
}