package sample

import (
	"errors"
	"math"
	"math/rand/v2"
	"slices"
	"sync"

	"github.com/ollama/ollama/llama"
)

// token represents information about a single token during sampling
type token struct {
	id    int32   // The token's unique identifier
	value float32 // The raw logit or probability from the model
}

type Sampler struct {
	rng         *rand.Rand
	topK        int
	topP        float32
	minP        float32
	temperature float32
	grammar     *Grammar
}
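
// A Sampler narrows and draws from a token distribution in a fixed order:
// top-k truncation, temperature scaling, softmax, top-p (nucleus) filtering,
// min-p filtering, and finally a weighted random draw. A temperature of 0
// short-circuits to greedy argmax, and an optional grammar constrains the
// accepted token.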

func (s *Sampler) Sample(logits []float32) (int32, error) {
	if len(logits) == 0 {
		return -1, errors.New("sample: no logits provided to sample")
	}

	tokens := make([]token, len(logits))
	for i := range logits {
		tokens[i].id = int32(i)
		tokens[i].value = logits[i]
	}

	t, err := s.sample(tokens)
	if err != nil {
		return -1, err
	}

	if s.grammar != nil {
		// optimization: first check if the max logit is accepted by the grammar
		// if the max logit is rejected, apply the grammar to all logits (slower)
		top := []token{t}
		s.grammar.Apply(top)
		if !math.IsInf(float64(top[0].value), -1) {
			s.grammar.Accept(top[0].id)
			return top[0].id, nil
		}

		// since .sample has the side effect of modifying the tokens,
		// we need to reset them before applying the grammar and
		// sampling again
		for i := range logits {
			tokens[i].id = int32(i)
			tokens[i].value = logits[i]
		}
		s.grammar.Apply(tokens)
		t, err = s.sample(tokens)
		if err != nil {
			return -1, err
		}
		s.grammar.Accept(t.id)
	}

	return t.id, nil
}
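
// exampleSample is an illustrative sketch, not part of the original file:
// it shows one plausible way to drive the sampler. The vocabulary size,
// parameter values, and favored token below are all hypothetical.
func exampleSample() (int32, error) {
	// temperature 0.8, topK 40, topP 0.9, minP 0.05, fixed seed 42, no grammar
	s := NewSampler(0.8, 40, 0.9, 0.05, 42, nil)
	logits := make([]float32, 32000) // one raw logit per vocabulary token
	logits[7] = 5.0                  // pretend the model strongly favors token 7
	return s.Sample(logits)
}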

// greedy returns the highest probability token from the tokens
func greedy(tokens []token) token {
	max := tokens[0]
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > max.value {
			max = tokens[i]
		}
	}

	return max
}

// sample returns a single token drawn from the tokens according to the
// sampler parameters. It also has the side effect of modifying the tokens
func (s *Sampler) sample(tokens []token) (token, error) {
	if s.temperature == 0 {
		return greedy(tokens), nil
	}

	// topK also sorts the tokens in descending order of logits
	tokens = topK(tokens, s.topK)

	// scale and normalize the tokens in place
	temperature(tokens, s.temperature)
	softmax(tokens)

	tokens = topP(tokens, s.topP)
	tokens = minP(tokens, s.minP)

	var r float32
	if s.rng != nil {
		r = s.rng.Float32()
	} else {
		r = rand.Float32()
	}

	// Calculate cumulative sum of probabilities
	var sum float32
	for i := range tokens {
		sum += tokens[i].value
		tokens[i].value = sum
	}
	r *= tokens[len(tokens)-1].value

	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
		if token.value < target {
			return -1
		}
		return 1
	})

	if math.IsNaN(float64(sum)) {
		return token{}, errors.New("sample: logits sum to NaN, check model output")
	}
	return tokens[idx], nil
}
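
// To make the weighted draw above concrete (numbers illustrative, not from
// the source): probabilities {0.5, 0.3, 0.2} accumulate to {0.5, 0.8, 1.0};
// scaling a uniform r in [0, 1) by the final cumulative value and binary
// searching selects index 0 when r < 0.5, index 1 when 0.5 <= r < 0.8, and
// index 2 otherwise, so each token is chosen in proportion to its
// probability mass.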

// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
	var rng *rand.Rand
	if seed != -1 {
		// PCG requires two parameters: sequence and stream
		// Use original seed for sequence
		sequence := uint64(seed)
		// Use golden ratio hash to generate statistically independent seeds
		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
	}
	if temperature < 0.0 {
		temperature = 0.0
	}

	if topP < 0.0 {
		topP = 0.0
	}
	if topP >= 1.0 {
		topP = 1.0
	}

	if minP < 0.0 {
		minP = 0.0
	}
	if minP >= 1.0 {
		minP = 1.0
	}

	return Sampler{
		rng:         rng,
		topK:        topK,
		topP:        topP,
		minP:        minP,
		temperature: temperature,
		grammar:     grammar,
	}
}
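
// exampleDeterministic is an illustrative sketch, not from the original
// source: two samplers built with the same non-negative seed share identical
// PCG streams, so they select the same token given the same logits. The
// parameter values here are hypothetical.
func exampleDeterministic(logits []float32) bool {
	a := NewSampler(0.7, 40, 0.9, 0.05, 1234, nil)
	b := NewSampler(0.7, 40, 0.9, 0.05, 1234, nil)
	ta, errA := a.Sample(logits)
	tb, errB := b.Sample(logits)
	return errA == nil && errB == nil && ta == tb
}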

type Grammar struct {
	vocab   *Vocab
	grammar string
	sampler *llama.Sampler
}

func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
	v, err := vocab.Load()
	if err != nil {
		return nil, err
	}

	return &Grammar{
		vocab:   vocab,
		grammar: grammar,
		sampler: llama.NewGrammarSampler(v, grammar),
	}, nil
}
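
// A minimal construction sketch (the vocabulary path and grammar text are
// hypothetical): the vocabulary loads lazily, the grammar compiles against
// it, and the result plugs into NewSampler.
//
//	vocab := NewVocab("/path/to/model.gguf")
//	g, err := NewGrammar(vocab, `root ::= "yes" | "no"`)
//	if err != nil {
//		// the vocabulary file could not be loaded
//	}
//	s := NewSampler(0.8, 40, 0.9, 0.05, -1, g)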

func (g *Grammar) Apply(tokens []token) {
	tds := make([]llama.TokenData, len(tokens))
	for i, token := range tokens {
		tds[i].Id = token.id
		tds[i].Logit = token.value
	}

	g.sampler.Apply(tds)

	for i := range tokens {
		tokens[i].value = tds[i].Logit
	}
}
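
// Apply mutates the logits in place: tokens the grammar rejects come back
// with a logit of negative infinity, which is what Sample's fast-path check
// looks for after the grammar pass.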

func (g *Grammar) Accept(token int32) {
	g.sampler.Accept(token)
}

type Vocab struct {
	once  sync.Once
	vocab *llama.Vocab
	err   error
	path  string
}

func NewVocab(path string) *Vocab {
	return &Vocab{path: path}
}

// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
	v.once.Do(func() {
		vocab, err := llama.LoadVocabFromFile(v.path)
		if err != nil {
			v.err = err
			return
		}
		v.vocab = vocab
	})
	return v.vocab, v.err
}
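
// A concurrency sketch (illustrative only): Load may be called from many
// goroutines; sync.Once guarantees LoadVocabFromFile runs at most once and
// every caller observes the same *llama.Vocab or error.
//
//	v := NewVocab("/path/to/model.gguf") // hypothetical path
//	var wg sync.WaitGroup
//	for i := 0; i < 4; i++ {
//		wg.Add(1)
//		go func() {
//			defer wg.Done()
//			_, _ = v.Load()
//		}()
//	}
//	wg.Wait()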