package sample

import (
	"errors"
	"math"
	"math/rand/v2"
	"slices"
	"sync"

	"github.com/ollama/ollama/llama"
)

// token represents information about a single token during sampling
type token struct {
	id    int32   // The token's unique identifier
	value float32 // The raw logit or probability from the model
}

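// Sampler holds the parameters used to pick the next token from a logit
// distribution, along with an optional seeded RNG and an optional grammar
// constraint.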
type Sampler struct {
	rng         *rand.Rand
	topK        int
	topP        float32
	minP        float32
	temperature float32
	grammar     *Grammar
}

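// Sample picks a token id from the raw logits. When a grammar is set, the
// sampled token is checked against it first; if the grammar rejects it, the
// grammar is applied to the full distribution and sampling runs again.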
func (s *Sampler) Sample(logits []float32) (int32, error) {
	tokens := make([]token, len(logits))
	for i := range logits {
		tokens[i].id = int32(i)
		tokens[i].value = logits[i]
	}

	t, err := s.sample(tokens)
	if err != nil {
		return -1, err
	}

	if s.grammar != nil {
		// optimization: first check whether the sampled token is accepted by
		// the grammar; if it is rejected, apply the grammar to all logits (slower)
		top := []token{t}
		s.grammar.Apply(top)
		if !math.IsInf(float64(top[0].value), -1) {
			s.grammar.Accept(top[0].id)
			return top[0].id, nil
		}

		// since sample has the side effect of modifying the tokens,
		// we need to rebuild them from the raw logits before applying
		// the grammar and sampling again
		for i := range logits {
			tokens[i].id = int32(i)
			tokens[i].value = logits[i]
		}
		s.grammar.Apply(tokens)
		t, err = s.sample(tokens)
		if err != nil {
			return -1, err
		}
		s.grammar.Accept(t.id)
	}

	return t.id, nil
}

// greedy returns the token with the highest logit
func greedy(tokens []token) token {
	max := tokens[0]
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > max.value {
			max = tokens[i]
		}
	}

	return max
}

// sample returns a token drawn from the tokens according to the sampler
// parameters, falling back to greedy selection when temperature is 0.
// It has the side effect of modifying the tokens in place.
func (s *Sampler) sample(tokens []token) (token, error) {
	if s.temperature == 0 {
		return greedy(tokens), nil
	}

	// topK also sorts the tokens in descending order of logits
	tokens = topK(tokens, s.topK)

	// scale and normalize the tokens in place
	temperature(tokens, s.temperature)
	softmax(tokens)

	tokens = topP(tokens, s.topP)
	tokens = minP(tokens, s.minP)

	// TODO: this should fall back to greedy sampling
	// or topP, topK values etc should be such that
	// there are always tokens to sample from
	if len(tokens) == 0 {
		return token{}, errors.New("no tokens to sample from")
	}

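	// draw a uniform random value, preferring the seeded source when one
	// was configured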
	var r float32
	if s.rng != nil {
		r = s.rng.Float32()
	} else {
		r = rand.Float32()
	}

	// Calculate cumulative sum of probabilities
	var sum float32
	for i := range tokens {
		sum += tokens[i].value
		tokens[i].value = sum
	}
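	// scale r by the total remaining mass: after topP/minP filtering the
	// probabilities no longer sum to 1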
	r *= tokens[len(tokens)-1].value

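	// binary search for the first token whose cumulative value reaches r;
	// the comparator never returns 0, so BinarySearchFunc reports the
	// insertion index rather than an exact match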
	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
		if token.value < target {
			return -1
		}
		return 1
	})

	return tokens[idx], nil
}

// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
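// NewSampler clamps the parameters to their valid ranges and, when seed is
// not -1, constructs a deterministic PCG source. A minimal usage sketch
// (the parameter values below are illustrative only, not recommendations):
//
//	s := NewSampler(0.8, 40, 0.9, 0.05, 42, nil)
//	id, err := s.Sample(logits)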
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
	var rng *rand.Rand
	if seed != -1 {
		// PCG requires two parameters: sequence and stream
		// Use original seed for sequence
		sequence := uint64(seed)
		// Use golden ratio hash to generate statistically independent seeds
		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
	}
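	// clamp the sampling parameters into their supported ranges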
	if temperature < 0.0 {
		temperature = 0.0
	}

	if topP < 0.0 {
		topP = 0.0
	}
	if topP >= 1.0 {
		topP = 1.0
	}

	if minP < 0.0 {
		minP = 0.0
	}
	if minP >= 1.0 {
		minP = 1.0
	}

	return Sampler{
		rng:         rng,
		topK:        topK,
		topP:        topP,
		minP:        minP,
		temperature: temperature,
		grammar:     grammar,
	}
}

// Grammar constrains sampling to tokens that the underlying llama.cpp
// grammar sampler accepts.
type Grammar struct {
	vocab   *Vocab
	grammar string
	sampler *llama.Sampler
}

// NewGrammar loads the vocabulary and wraps a llama.cpp grammar sampler
// around it for the given grammar string.
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
	v, err := vocab.Load()
	if err != nil {
		return nil, err
	}

	return &Grammar{
		vocab:   vocab,
		grammar: grammar,
		sampler: llama.NewGrammarSampler(v, grammar),
	}, nil
}

// Apply runs the grammar sampler over the tokens in place; tokens the
// grammar rejects have their logits set to negative infinity.
func (g *Grammar) Apply(tokens []token) {
	tds := make([]llama.TokenData, len(tokens))
	for i, token := range tokens {
		tds[i].Id = token.id
		tds[i].Logit = token.value
	}

	g.sampler.Apply(tds)

	for i := range tokens {
		tokens[i].value = tds[i].Logit
	}
}

// Accept advances the grammar state by one sampled token.
func (g *Grammar) Accept(token int32) {
	g.sampler.Accept(token)
}

// Vocab lazily loads a llama vocabulary from a file path on first use.
type Vocab struct {
	once  sync.Once
	vocab *llama.Vocab
	err   error
	path  string
}

func NewVocab(path string) *Vocab {
	return &Vocab{path: path}
}

// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
	v.once.Do(func() {
		vocab, err := llama.LoadVocabFromFile(v.path)
		if err != nil {
			v.err = err
			return
		}
		v.vocab = vocab
	})
	return v.vocab, v.err
}