"vscode:/vscode.git/clone" did not exist on "8f9ccb9ac6d2a28fbecc1c46408fae74dfe6b393"
samplers.go 4.65 KB
Newer Older
1
2
3
4
package sample

import (
	"errors"
	"math"
	"math/rand/v2"
	"slices"

	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/model"
)

// token represents information about a single token during sampling
type token struct {
	id    int32   // The token's unique identifier
	value float32 // The raw logit or probability from the model
}

// Sampler selects a token id from model logits using temperature, top-k,
// top-p, and min-p filtering, with an optional grammar constraint.
type Sampler struct {
	rng         *rand.Rand
	topK        int
	topP        float32
	minP        float32
	temperature float32
	grammar     *GrammarSampler
}

// Sample returns the id of a token sampled from the given logits.
func (s *Sampler) Sample(logits []float32) (int32, error) {
	if len(logits) == 0 {
		return -1, errors.New("sample: no logits provided to sample")
	}

	tokens := make([]token, len(logits))
	for i := range logits {
		tokens[i].id = int32(i)
		tokens[i].value = logits[i]
	}

	t, err := s.sample(tokens)
	if err != nil {
		return -1, err
	}

	if s.grammar != nil {
		// optimization: first check if the max logit is accepted by the grammar
		// if the max logit is rejected, apply the grammar to all logits (slower)
		top := []token{t}
		s.grammar.Apply(top)
		if !math.IsInf(float64(top[0].value), -1) {
			s.grammar.Accept(top[0].id)
			return top[0].id, nil
		}

		// since .sample has side effects of modifying the tokens
		// we need to reset them before applying the grammar and
		// sampling again
		for i := range logits {
			tokens[i].id = int32(i)
			tokens[i].value = logits[i]
		}
		s.grammar.Apply(tokens)
		t, err = s.sample(tokens)
		if err != nil {
			return -1, err
		}
		s.grammar.Accept(t.id)
	}

	return t.id, nil
}

// greedy returns the highest probability token from the tokens
func greedy(tokens []token) token {
	max := tokens[0]
	for i := 1; i < len(tokens); i++ {
		if tokens[i].value > max.value {
			max = tokens[i]
		}
	}

	return max
}

// sample returns a token drawn from tokens according to the sampler's
// parameters (greedy when temperature is 0). It has the side effect of
// modifying the tokens slice in place.
func (s *Sampler) sample(tokens []token) (token, error) {
	if s.temperature == 0 {
		return greedy(tokens), nil
	}

	// topK also sorts the tokens in descending order of logits
	tokens = topK(tokens, s.topK)

	// scale and normalize the tokens in place
	temperature(tokens, s.temperature)
	softmax(tokens)

	tokens = topP(tokens, s.topP)
	tokens = minP(tokens, s.minP)

	var r float32
	if s.rng != nil {
		r = s.rng.Float32()
	} else {
		r = rand.Float32()
	}

	// Calculate cumulative sum of probabilities
	var sum float32
	for i := range tokens {
		sum += tokens[i].value
		tokens[i].value = sum
	}
	r *= tokens[len(tokens)-1].value

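	// Binary search for the first token whose cumulative value reaches r,
	// i.e. a weighted random draw over the remaining candidates.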
	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
		if token.value < target {
			return -1
		}
		return 1
	})

	if math.IsNaN(float64(sum)) {
		return token{}, errors.New("sample: logits sum to NaN, check model output")
	}
	return tokens[idx], nil
}

// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
// NewSampler creates a Sampler with the given parameters, clamping them to
// valid ranges. A seed of -1 selects the shared global random source.
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
	var rng *rand.Rand
	if seed != -1 {
		// PCG requires two parameters: sequence and stream
		// Use original seed for sequence
		sequence := uint64(seed)
		// Use golden ratio hash to generate statistically independent seeds
		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
	}
	if temperature < 0.0 {
		temperature = 0.0
	}

	if topP < 0.0 {
		topP = 0.0
	}
	if topP >= 1.0 {
		topP = 1.0
	}

	if minP < 0.0 {
		minP = 0.0
	}
	if minP >= 1.0 {
		minP = 1.0
	}

	return Sampler{
		rng:         rng,
		topK:        topK,
		topP:        topP,
		minP:        minP,
		temperature: temperature,
		grammar:     grammar,
	}
}
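
// Usage sketch: building a Sampler and drawing a token id from a slice of
// logits. The parameter values below are illustrative assumptions, not
// defaults defined by this package.
//
//	logits := []float32{0.1, 2.3, -0.5} // per-token scores from the model
//	sampler := NewSampler(0.7, 40, 0.9, 0.05, -1, nil)
//	id, err := sampler.Sample(logits)
//	if err != nil {
//		// handle the sampling error
//	}
//	_ = id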

// GrammarSampler wraps a llama.Grammar to constrain token selection to the
// sequences the grammar allows.
type GrammarSampler struct {
	grammar *llama.Grammar
}

// NewGrammarSampler builds a GrammarSampler over the model's vocabulary from
// the given grammar string.
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
	vocabIds := make([]uint32, len(model.Vocabulary().Values))
	pieces := make([]string, len(model.Vocabulary().Values))
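	// Decode every vocabulary id into its text piece so the grammar can match
	// against token strings.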
	for i := range model.Vocabulary().Values {
		pieces[i], _ = model.Decode([]int32{int32(i)})
		vocabIds[i] = uint32(i)
	}

	grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
	if grammar == nil {
		return nil, errors.New("sample: failed to initialize grammar")
	}

	return &GrammarSampler{grammar: grammar}, nil
}

// Apply applies the grammar to the tokens, updating each token's logit in place.
func (g *GrammarSampler) Apply(tokens []token) {
	tds := make([]llama.TokenData, len(tokens))
	for i, token := range tokens {
		tds[i].ID = token.id
		tds[i].Logit = token.value
	}
	g.grammar.Apply(tds)

	for i := range tokens {
		tokens[i].value = tds[i].Logit
	}
}

// Accept advances the grammar state with the chosen token id.
func (g *GrammarSampler) Accept(token int32) {
	g.grammar.Accept(token)
}

// Free releases the resources held by the underlying grammar.
func (g *GrammarSampler) Free() {
	g.grammar.Free()
}