OpenDAS / ollama — commit e093db92 (unverified)

sample: temporarily use grammars for constrained generation in new engine (#9586)

Authored Mar 10, 2025 by Jeffrey Morgan; committed by GitHub on Mar 10, 2025
Parent: a1cda80b

Showing 10 changed files with 298 additions and 210 deletions (+298, -210)
llama/llama.go                       +68    -0
llama/sampling_ext.cpp               +22    -0
llama/sampling_ext.h                  +3    -0
llm/server.go                        +16   -21
runner/ollamarunner/runner.go        +19    -4
sample/samplers.go                  +135   -54
sample/samplers_benchmark_test.go     +8   -20
sample/samplers_test.go               +5   -89
sample/transforms.go                  +8    -8
sample/transforms_test.go            +14   -14
llama/llama.go
@@ -245,6 +245,20 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
 	return &m, nil
 }
 
+func LoadVocabFromFile(path string) (*Vocab, error) {
+	mp := C.CString(path)
+	defer C.free(unsafe.Pointer(mp))
+	v := Vocab{c: C.llama_load_vocab_from_file(mp)}
+	if v.c == nil {
+		return nil, fmt.Errorf("unable to load vocab: %s", path)
+	}
+
+	return &v, nil
+}
+
+func FreeVocab(vocab *Vocab) {
+	C.llama_free_vocab(vocab.c)
+}
+
 func FreeModel(model *Model) {
 	C.llama_model_free(model.c)
 }
@@ -293,6 +307,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 	return nil
 }
 
+type Vocab struct {
+	c *C.struct_llama_vocab
+}
+
 func (m *Model) Vocab() *C.struct_llama_vocab {
 	return C.llama_model_get_vocab(m.c)
 }
@@ -669,3 +687,53 @@ func SchemaToGrammar(schema []byte) []byte {
 	}
 	return buf[:n]
 }
+
+type Sampler struct {
+	c *C.struct_llama_sampler
+}
+
+func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
+	cGrammar := C.CString(grammar)
+	cRoot := C.CString("root")
+	defer C.free(unsafe.Pointer(cGrammar))
+	defer C.free(unsafe.Pointer(cRoot))
+
+	sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
+
+	return sampler
+}
+
+func (s *Sampler) Accept(token int32) {
+	C.llama_sampler_accept(s.c, C.llama_token(token))
+}
+
+type TokenData struct {
+	Id    int32
+	Logit float32
+}
+
+func (s *Sampler) Apply(tokens []TokenData) {
+	tds := make([]C.struct_llama_token_data, len(tokens))
+	for i, token := range tokens {
+		tds[i] = C.struct_llama_token_data{
+			id:    C.int32_t(token.Id),
+			logit: C.float(token.Logit),
+			p:     C.float(0.0),
+		}
+	}
+	tda := &C.llama_token_data_array{
+		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
+		size:     C.size_t(len(tokens)),
+		selected: C.int64_t(-1),
+		sorted:   C.bool(false),
+	}
+
+	var pinner runtime.Pinner
+	pinner.Pin(&tds[0])
+	defer pinner.Unpin()
+
+	C.llama_sampler_apply(s.c, tda)
+	for i := range tokens {
+		tokens[i].Logit = float32(tds[i].logit)
+	}
+}
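Taken together, the new bindings cover the full grammar-sampling round trip: load a vocabulary, build a grammar sampler, mask one step's logits, and advance the grammar with the chosen token. A minimal sketch of how a caller might use them follows; the model path, grammar, and logits are illustrative placeholders, not values from this commit.

package grammardemo

import (
	"fmt"

	"github.com/ollama/ollama/llama"
)

// MaskLogits is a sketch of how the new bindings fit together. modelPath,
// grammar, and logits are placeholders, not values from this commit.
func MaskLogits(modelPath, grammar string, logits []float32) ([]llama.TokenData, error) {
	vocab, err := llama.LoadVocabFromFile(modelPath)
	if err != nil {
		return nil, fmt.Errorf("load vocab: %w", err)
	}
	defer llama.FreeVocab(vocab)

	sampler := llama.NewGrammarSampler(vocab, grammar)

	tokens := make([]llama.TokenData, len(logits))
	for i, logit := range logits {
		tokens[i] = llama.TokenData{Id: int32(i), Logit: logit}
	}

	// Tokens the grammar cannot accept at this step come back with their
	// logit forced to negative infinity, which is what the sample package
	// later checks with math.IsInf.
	sampler.Apply(tokens)

	return tokens, nil
}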
llama/sampling_ext.cpp
@@ -2,6 +2,9 @@
 #include "sampling.h"
 #include "sampling_ext.h"
 #include "json-schema-to-grammar.h"
+#include "llama.h"
+#include "llama-model.h"
+#include "llama-model-loader.h"
 
 struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params)
 {
     try {
@@ -64,3 +67,22 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
         return 0;
     }
 }
+
+struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
+    llama_vocab * vocab = new llama_vocab();
+    try {
+        const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
+        std::vector<std::string> splits = {};
+        llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
+        vocab->load(ml, kv);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
+        return nullptr;
+    }
+
+    return vocab;
+}
+
+void llama_free_vocab(struct llama_vocab * vocab) {
+    delete vocab;
+}
llama/sampling_ext.h
@@ -35,6 +35,9 @@ extern "C"
     int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
+
+    struct llama_vocab * llama_load_vocab_from_file(const char * fname);
+    void llama_free_vocab(struct llama_vocab * vocab);
 
 #ifdef __cplusplus
 }
 #endif
llm/server.go
@@ -729,29 +729,24 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}
 
 	if len(req.Format) > 0 {
-		format := string(req.Format)
-		if format != `null` && format != `""` {
-			if s.textProcessor != nil {
-				// New engine handles this on the backend
-				request["format"] = req.Format
-			} else {
-				// old engine
-				switch format {
-				case `"json"`:
-					request["grammar"] = grammarJSON
-				default:
-					if req.Format[0] != '{' {
-						return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
-					}
-
-					// User provided a JSON schema
-					g := llama.SchemaToGrammar(req.Format)
-					if g == nil {
-						return fmt.Errorf("invalid JSON schema in format")
-					}
-					request["grammar"] = string(g)
-				}
-			}
-		}
+		switch string(req.Format) {
+		case `null`, `""`:
+			// Field was set, but "missing" a value. We accept
+			// these as "not set".
+			break
+		case `"json"`:
+			request["grammar"] = grammarJSON
+		default:
+			if req.Format[0] != '{' {
+				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
+			}
+
+			// User provided a JSON schema
+			g := llama.SchemaToGrammar(req.Format)
+			if g == nil {
+				return fmt.Errorf("invalid JSON schema in format")
+			}
+			request["grammar"] = string(g)
+		}
 	}
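The default branch above is the only path that builds a grammar from a caller-supplied JSON Schema; llama.SchemaToGrammar is already used by the old engine and is unchanged here. A standalone sketch of that conversion, with an illustrative schema literal that is not taken from the commit:

package main

import (
	"fmt"

	"github.com/ollama/ollama/llama"
)

func main() {
	// Illustrative schema; any JSON Schema object passed as "format" takes
	// the same path through the default branch above.
	schema := []byte(`{"type":"object","properties":{"answer":{"type":"string"}},"required":["answer"]}`)

	g := llama.SchemaToGrammar(schema)
	if g == nil {
		fmt.Println("invalid JSON schema in format")
		return
	}

	// The resulting GBNF grammar string is what the server forwards to the
	// runner as the "grammar" field.
	fmt.Println(string(g))
}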
runner/ollamarunner/runner.go
@@ -254,6 +254,12 @@ type Server struct {
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
+
+	// vocab is a llama.cpp vocab required for gammar-based
+	// constrained generation (json mode, structured outputs)
+	// TODO: this is temporary until Ollama sampling supports
+	// constrained generation
+	vocab *sample.Vocab
 }
 
 func (s *Server) allNil() bool {
@@ -574,18 +580,25 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	var grammar *sample.Grammar
+	var err error
+	if req.Grammar != "" {
+		grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
+		if err != nil {
+			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
+			return
+		}
+	}
+
 	sampler := sample.NewSampler(
 		req.Temperature,
 		req.TopK,
 		req.TopP,
 		req.MinP,
 		req.Seed,
+		grammar,
 	)
 
-	if req.Grammar != "" {
-		panic("grammars are not yet supported")
-	}
-
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict: req.NumPredict,
 		stop:       req.Stop,
@@ -797,6 +810,8 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
+	s.vocab = sample.NewVocab(mpath)
+
 	// TODO(jessegross): LoRA loading
 	if lpath.String() != "" {
 		panic("loras are not yet implemented")
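The three runner changes fit together as follows: loadModel records a lazily-loaded sample.Vocab, and the completion handler turns a non-empty req.Grammar into a sample.Grammar that the sampler carries. A condensed sketch of that flow, using a stand-in parameter struct rather than the runner's real CompletionRequest type:

package runnerdemo

import "github.com/ollama/ollama/sample"

// completionParams stands in for the handful of runner request fields used
// below; it is illustrative, not the runner's actual CompletionRequest type.
type completionParams struct {
	Temperature float32
	TopK        int
	TopP        float32
	MinP        float32
	Seed        int
	Grammar     string
}

// buildSampler mirrors the flow the diff adds to the completion handler: an
// optional grammar built from the lazily-loaded vocab, then a sampler that
// carries it.
func buildSampler(vocab *sample.Vocab, p completionParams) (sample.Sampler, error) {
	var grammar *sample.Grammar
	if p.Grammar != "" {
		g, err := sample.NewGrammar(vocab, p.Grammar)
		if err != nil {
			return sample.Sampler{}, err
		}
		grammar = g
	}

	return sample.NewSampler(p.Temperature, p.TopK, p.TopP, p.MinP, p.Seed, grammar), nil
}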
sample/samplers.go
@@ -2,43 +2,88 @@ package sample
 
 import (
 	"errors"
 	"math"
 	"math/rand/v2"
 	"slices"
-)
+	"sync"
 
-// Sampler is not thread-safe. Each goroutine should have its own instance
-type Sampler interface {
-	Sample([]float32) (int32, error)
-}
+	"github.com/ollama/ollama/llama"
+)
 
-// logit represents information about a single token during sampling
-type logit struct {
+// token represents information about a single token during sampling
+type token struct {
 	id    int32   // The token's unique identifier
 	value float32 // The raw logit or probability from the model
 }
 
-type weighted struct {
+type Sampler struct {
 	rng         *rand.Rand
-	tokens      []logit
 	topK        int
 	topP        float32
 	minP        float32
 	temperature float32
+	grammar     *Grammar
 }
 
-func (s *weighted) Sample(logits []float32) (int32, error) {
-	if len(s.tokens) < len(logits) {
-		s.tokens = make([]logit, len(logits))
-	}
-
-	tokens := s.tokens[:len(logits)]
-
-	for i, v := range logits {
-		tokens[i].id = int32(i)
-		tokens[i].value = v
-	}
-
+func (s *Sampler) Sample(logits []float32) (int32, error) {
+	tokens := make([]token, len(logits))
+	for i := range logits {
+		tokens[i].id = int32(i)
+		tokens[i].value = logits[i]
+	}
+
+	t, err := s.sample(tokens)
+	if err != nil {
+		return -1, err
+	}
+
+	if s.grammar != nil {
+		// optimization: first check if the max logit is accepted by the grammar
+		// if the max logit is rejected, apply the grammar to all logits (slower)
+		top := []token{t}
+		s.grammar.Apply(top)
+		if !math.IsInf(float64(top[0].value), -1) {
+			s.grammar.Accept(top[0].id)
+			return top[0].id, nil
+		}
+
+		// since .sample has side effects of modifying the tokens
+		// we need to reset them before applying the grammar and
+		// sampling again
+		for i := range logits {
+			tokens[i].id = int32(i)
+			tokens[i].value = logits[i]
+		}
+		s.grammar.Apply(tokens)
+		t, err = s.sample(tokens)
+		if err != nil {
+			return -1, err
+		}
+		s.grammar.Accept(t.id)
+	}
+
+	return t.id, nil
+}
+
+// greedy returns the highest probability token from the tokens
+func greedy(tokens []token) token {
+	max := tokens[0]
+	for i := 1; i < len(tokens); i++ {
+		if tokens[i].value > max.value {
+			max = tokens[i]
+		}
+	}
+
+	return max
+}
+
+// sample returns the highest probability token from the tokens
+// given sampler parameters. It also has side effects of modifying the tokens
+func (s *Sampler) sample(tokens []token) (token, error) {
+	if s.temperature == 0 {
+		return greedy(tokens), nil
+	}
+
 	// Tokens are sorted by logits in TopK or SortTokens
 	if s.topK > 0 {
 		tokens = topK(tokens, s.topK)
 	} else {
@@ -47,12 +92,14 @@ func (s *weighted) Sample(logits []float32) (int32, error) {
 	tokens = temperature(tokens, s.temperature)
 	tokens = softmax(tokens)
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
+	// TODO: this should fall back to greedy sampling
+	// or topP, topK values etc should be such that
+	// there are always tokens to sample from
 	if len(tokens) == 0 {
-		return -1, errors.New("no valid logits found for weighted sampling")
+		return token{}, errors.New("no tokens to sample from")
 	}
 
 	var r float32
@@ -70,48 +117,18 @@ func (s *weighted) Sample(logits []float32) (int32, error) {
 	}
 	r *= tokens[len(tokens)-1].value
 
-	idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
-		// Compare cumulative probabilities
+	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
 		if token.value < target {
 			return -1
 		}
-		// First token that exceeds target
 		return 1
 	})
 
 	if idx >= len(tokens) {
 		idx = len(tokens) - 1
 	}
 
-	return tokens[idx].id, nil
-}
-
-type greedy struct{}
-
-// Greedy sample returns the index of the maximum value in logits.
-func (s greedy) Sample(logits []float32) (int32, error) {
-	if len(logits) == 0 {
-		return -1, errors.New("no logits provided for greedy sampling")
-	}
-
-	maxIdx := 0
-	maxVal := logits[0]
-	for i := 1; i < len(logits); i++ {
-		if logits[i] > maxVal {
-			maxVal = logits[i]
-			maxIdx = i
-		}
-	}
-
-	return int32(maxIdx), nil
+	return tokens[idx], nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
-func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
-	if temperature == 0 {
-		return &greedy{}
-	}
-
+func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
 	var rng *rand.Rand
 	if seed != -1 {
 		// PCG requires two parameters: sequence and stream
@@ -120,7 +137,9 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 		// Use golden ratio hash to generate statistically independent seeds
 		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
 	}
-	temperature = max(temperature, 1)
+	if temperature < 0.0 {
+		temperature = 0.0
+	}
 
 	if topP < 0.0 {
 		topP = 0.0
@@ -136,11 +155,73 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 		minP = 1.0
 	}
 
-	return &weighted{
+	return Sampler{
 		rng:         rng,
 		topK:        topK,
 		topP:        topP,
 		minP:        minP,
 		temperature: temperature,
+		grammar:     grammar,
 	}
 }
+
+type Grammar struct {
+	vocab   *Vocab
+	grammar string
+	sampler *llama.Sampler
+}
+
+func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
+	v, err := vocab.Load()
+	if err != nil {
+		return nil, err
+	}
+
+	return &Grammar{
+		vocab:   vocab,
+		grammar: grammar,
+		sampler: llama.NewGrammarSampler(v, grammar),
+	}, nil
+}
+
+func (g *Grammar) Apply(tokens []token) {
+	tds := make([]llama.TokenData, len(tokens))
+	for i, token := range tokens {
+		tds[i].Id = token.id
+		tds[i].Logit = token.value
+	}
+
+	g.sampler.Apply(tds)
+
+	for i := range tokens {
+		tokens[i].value = tds[i].Logit
+	}
+}
+
+func (g *Grammar) Accept(token int32) {
+	g.sampler.Accept(token)
+}
+
+type Vocab struct {
+	once  sync.Once
+	vocab *llama.Vocab
+	err   error
+	path  string
+}
+
+func NewVocab(path string) *Vocab {
+	return &Vocab{path: path}
+}
+
+// Load returns the lazily-loaded vocabulary
+func (v *Vocab) Load() (*llama.Vocab, error) {
+	v.once.Do(func() {
+		vocab, err := llama.LoadVocabFromFile(v.path)
+		if err != nil {
+			v.err = err
+			return
+		}
+
+		v.vocab = vocab
+	})
+
+	return v.vocab, v.err
+}
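End to end, the new sample package API reads: NewVocab defers vocabulary loading until a grammar is actually requested, NewGrammar wraps the llama.cpp grammar sampler, and NewSampler carries the grammar so Sample can mask logits and advance the grammar state. A minimal sketch of the whole flow; the model path, GBNF grammar, and logits are placeholders, and the program only runs against a real GGUF model:

package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/sample"
)

func main() {
	// Placeholders: a real GGUF model path and any GBNF grammar would go here.
	vocab := sample.NewVocab("/path/to/model.gguf")

	grammar, err := sample.NewGrammar(vocab, `root ::= "yes" | "no"`)
	if err != nil {
		log.Fatal(err) // fails here unless the model path is real
	}

	s := sample.NewSampler(0.7, 40, 0.9, 0.05, 42, grammar)

	// logits would normally come from the model's forward pass; zeros are
	// only a stand-in so the call shape is visible.
	logits := make([]float32, 32000)
	id, err := s.Sample(logits)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("sampled token id:", id)
}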
sample/samplers_benchmark_test.go
@@ -16,13 +16,10 @@ func BenchmarkWeightedSampler(b *testing.B) {
 			logits[i] = float32(rand.Float64()*10 - 5)
 		}
 
-		sampler := NewSampler(0.8, 0, 0, 0, 42)
+		sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
 		b.ResetTimer()
 		for b.Loop() {
-			_, err := sampler.Sample(logits)
-			if err != nil {
-				b.Fatalf("Sampling failed: %v", err)
-			}
+			sampler.Sample(logits)
 		}
 	})
 }
@@ -52,30 +49,24 @@ func BenchmarkWeightedSampler(b *testing.B) {
 	for _, tc := range configs {
 		b.Run("Config"+tc.name, func(b *testing.B) {
-			sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
+			sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
 			sampler.Sample(logits)
 
 			b.ResetTimer()
 
 			for b.Loop() {
-				_, err := sampler.Sample(logits)
-				if err != nil {
-					b.Fatalf("Sampling failed: %v", err)
-				}
+				sampler.Sample(logits)
 			}
 		})
 	}
 
 	// Test with combined transforms separately - topK influences performance greatly
 	b.Run("TransformCombined", func(b *testing.B) {
-		sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
+		sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
 		b.ResetTimer()
 		for b.Loop() {
-			_, err := sampler.Sample(logits)
-			if err != nil {
-				b.Fatalf("Sampling failed: %v", err)
-			}
+			sampler.Sample(logits)
 		}
 	})
 }
@@ -90,14 +81,11 @@ func BenchmarkGreedySampler(b *testing.B) {
 			logits[i] = float32(rand.Float64()*10 - 5)
 		}
 
-		sampler := NewSampler(0, -1, 0, 0, -1)
+		sampler := NewSampler(0, -1, 0, 0, -1, nil)
 		b.ResetTimer()
 		for b.Loop() {
-			_, err := sampler.Sample(logits)
-			if err != nil {
-				b.Fatalf("Sampling failed: %v", err)
-			}
+			sampler.Sample(logits)
 		}
 	})
 }
sample/samplers_test.go
@@ -7,7 +7,7 @@ import (
 
 func TestWeighted(t *testing.T) {
 	logits := []float32{-10, 3, -10, -10}
-	sampler := NewSampler(0, 0, 0, 0, 0)
+	sampler := NewSampler(0, 0, 0, 0, 0, nil)
 	got, err := sampler.Sample(logits)
 	if err != nil {
 		t.Error(err)
@@ -19,7 +19,7 @@ func TestWeighted(t *testing.T) {
 	}
 
 	logits = []float32{-100, -10, 0, 10}
-	sampler = NewSampler(0, 0, 0, 0, 0)
+	sampler = NewSampler(0, 0, 0, 0, 0, nil)
 	got, err = sampler.Sample(logits)
 	if err != nil {
 		t.Error(err)
@@ -31,94 +31,10 @@ func TestWeighted(t *testing.T) {
 	}
 }
 
-func TestNewSampler(t *testing.T) {
-	tests := []struct {
-		name        string
-		temperature float32
-		topK        int
-		topP        float32
-		minP        float32
-		seed        int
-		wantGreedy  bool // Instead of wantErr, check if we get greedy sampler
-	}{
-		{
-			name:        "temperature",
-			temperature: 0.5,
-			wantGreedy:  false,
-		},
-		{
-			name:        "zero temperature - greedy",
-			temperature: 0,
-			wantGreedy:  true,
-		},
-		{
-			name:        "top k",
-			temperature: 0.1,
-			topK:        10,
-			wantGreedy:  false,
-		},
-		{
-			name:        "top p",
-			temperature: 0.1,
-			topP:        0.9,
-			wantGreedy:  false,
-		},
-		{
-			name:        "min p",
-			temperature: 0.1,
-			minP:        0.2,
-			wantGreedy:  false,
-		},
-		{
-			name:        "seed - weighted",
-			temperature: 0.1,
-			seed:        42,
-			wantGreedy:  false,
-		},
-		{
-			name:        "default values",
-			temperature: 0.8,
-			topK:        40,
-			topP:        0.9,
-			minP:        0.0,
-			seed:        0,
-			wantGreedy:  false,
-		},
-		{
-			name:        "all zeroes - greedy",
-			temperature: 0.0,
-			topK:        0,
-			topP:        0.0,
-			minP:        0.0,
-			seed:        0,
-			wantGreedy:  true,
-		},
-		{
-			name:        "all transforms",
-			temperature: 0.8,
-			topK:        50,
-			topP:        0.95,
-			minP:        0.1,
-			seed:        42,
-			wantGreedy:  false,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
-			_, isGreedy := sampler.(*greedy)
-			if isGreedy != tt.wantGreedy {
-				t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
-			}
-		})
-	}
-}
-
 func BenchmarkSample(b *testing.B) {
-	weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
 	samplers := map[string]Sampler{
-		"Greedy":   NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
-		"Weighted": weighted,
+		"Greedy":   NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
+		"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
 	}
 
 	// Generate random logits for benchmarking
@@ -132,7 +48,7 @@ func BenchmarkSample(b *testing.B) {
 		b.ResetTimer()
 		for b.Loop() {
 			if _, err := s.Sample(logits); err != nil {
-				b.Error(err)
+				b.Fatalf("error sampling: %v", err)
 			}
 		}
 	})
sample/transforms.go
@@ -5,7 +5,7 @@ import (
 	"slices"
 )
 
-func softmax(ts []logit) []logit {
+func softmax(ts []token) []token {
 	var sum float32
 	for i, v := range ts {
 		ts[i].value = float32(math.Exp(float64(v.value)))

@@ -19,7 +19,7 @@ func softmax(ts []logit) []logit {
 	return ts
 }
 
-func temperature(ti []logit, t float32) []logit {
+func temperature(ti []token, t float32) []token {
 	if t == 1 {
 		return ti
 	}

@@ -51,7 +51,7 @@ func temperature(ti []logit, t float32) []logit {
 // 1. Finds the smallest value between the node and its children
 // 2. If the node is not the smallest, swaps it with its smallest child
 // 3. Continues this process down the affected path until the min-heap property is restored
-func siftDown(data []logit, start, end int) {
+func siftDown(data []token, start, end int) {
 	root := start
 	for {
 		child := 2*root + 1

@@ -73,7 +73,7 @@ func siftDown(data []logit, start, end int) {
 }
 
 // topK limits the number of tokens considered to the k highest logits
-func topK(ts []logit, k int) []logit {
+func topK(ts []token, k int) []token {
 	if k >= len(ts) {
 		return ts
 	}

@@ -99,7 +99,7 @@ func topK(ts []logit, k int) []logit {
 }
 
 // topP limits tokens to those with cumulative probability p
-func topP(ts []logit, p float32) []logit {
+func topP(ts []token, p float32) []token {
 	if p == 1.0 {
 		return ts
 	}

@@ -118,7 +118,7 @@ func topP(ts []logit, p float32) []logit {
 }
 
 // minP limits tokens to those with cumulative probability p
-func minP(ts []logit, p float32) []logit {
+func minP(ts []token, p float32) []token {
 	if p == 1.0 {
 		return ts
 	}

@@ -146,7 +146,7 @@ func minP(ts []logit, p float32) []logit {
 // TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
 // Conting sort implementation to sort tokens by logits
-func sortLogits(tokens []logit) {
+func sortLogits(tokens []token) {
 	if len(tokens) <= 1 {
 		return
 	}

@@ -187,7 +187,7 @@ func sortLogits(tokens []logit) {
 	}
 
 	// Second pass: place elements in correct position
-	output := make([]logit, len(tokens))
+	output := make([]token, len(tokens))
 
 	// Track current positions
 	countsCopy := counts
sample/transforms_test.go
@@ -7,10 +7,10 @@ import (
 )
 
 // Helper to convert float64 slice to logit slice
-func toLogits(values []float64) []logit {
-	tokens := make([]logit, len(values))
+func toTokens(values []float64) []token {
+	tokens := make([]token, len(values))
 	for i, v := range values {
-		tokens[i] = logit{
+		tokens[i] = token{
 			id:    int32(i),
 			value: float32(v),
 		}

@@ -19,7 +19,7 @@ func toLogits(values []float64) []logit {
 }
 
 // Helper to compare logit slices
-func compareLogits(t *testing.T, name string, want []float64, got []logit) {
+func compareLogits(t *testing.T, name string, want []float64, got []token) {
 	t.Helper()
 	if len(want) != len(got) {
 		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))

@@ -36,13 +36,13 @@ func TestTemperature(t *testing.T) {
 	input := []float64{2, -1, 4, -3, 1, -2, 0}
 	want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
 
-	got := temperature(toLogits(input), 0.5)
+	got := temperature(toTokens(input), 0.5)
 	compareLogits(t, "Temperature", want, got)
 }
 
 func TestSoftmax(t *testing.T) {
 	input := []float64{-3, -2, -1, 0, 1, 2, 4}
-	got := softmax(toLogits(input))
+	got := softmax(toTokens(input))
 
 	// Check probabilities sum to 1
 	var sum float32

@@ -65,7 +65,7 @@ func TestTopK(t *testing.T) {
 	input := []float64{-3, -2, -1, 0, 1, 2, 4}
 
 	// Test k=3
-	got := topK(toLogits(input), 3)
+	got := topK(toTokens(input), 3)
 	if len(got) != 3 {
 		t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
 	}

@@ -74,13 +74,13 @@ func TestTopK(t *testing.T) {
 	compareLogits(t, "topK(3)", want, got)
 
 	// Test k > len
-	got = topK(toLogits(input), 10)
+	got = topK(toTokens(input), 10)
 	compareLogits(t, "topK(10)", input, got)
 }
 
 func TestTopP(t *testing.T) {
 	input := []float64{-3, -2, -1, 0, 1, 2, 4}
-	tokens := toLogits(input)
+	tokens := toTokens(input)
 
 	// First apply temperature and softmax to get probabilities
 	tokens = temperature(tokens, 1)

@@ -99,7 +99,7 @@ func TestTopP(t *testing.T) {
 func TestMinP(t *testing.T) {
 	input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
-	tokens := toLogits(input)
+	tokens := toTokens(input)
 
 	// First apply temperature and softmax
 	tokens = temperature(tokens, 1)

@@ -116,7 +116,7 @@ func TestMinP(t *testing.T) {
 func TestSortLogits(t *testing.T) {
 	input := []float64{3, 1, 4, 2, -1, 0, -2}
-	tokens := toLogits(input)
+	tokens := toTokens(input)
 
 	sortLogits(tokens)

@@ -133,15 +133,15 @@ func TestSortLogits(t *testing.T) {
 func BenchmarkTransforms(b *testing.B) {
 	// Generate random logits
-	tokens := make([]logit, 1<<16)
+	tokens := make([]token, 1<<16)
 	for i := range tokens {
-		tokens[i] = logit{
+		tokens[i] = token{
 			id:    int32(i),
 			value: rand.Float32(),
 		}
 	}
 
-	tokensCopy := make([]logit, len(tokens))
+	tokensCopy := make([]token, len(tokens))
 
 	b.Run("Temperature", func(b *testing.B) {
 		b.ResetTimer()