Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
7e34f4fb
Unverified
Commit
7e34f4fb
authored
Mar 10, 2025
by
Parth Sareen
Committed by
GitHub
Mar 10, 2025
Browse files
sample: add numerical stability to temperature/softmax transform (#9631)
parent
fe776293
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
42 deletions
+28
-42
sample/samplers.go
sample/samplers.go
+2
-1
sample/transforms.go
sample/transforms.go
+16
-24
sample/transforms_test.go
sample/transforms_test.go
+10
-17
No files found.
sample/samplers.go
View file @
7e34f4fb
...
@@ -90,8 +90,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
...
@@ -90,8 +90,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
sortLogits
(
tokens
)
sortLogits
(
tokens
)
}
}
// token logit values are updated to probabilities
tokens
=
temperature
(
tokens
,
s
.
temperature
)
tokens
=
temperature
(
tokens
,
s
.
temperature
)
tokens
=
softmax
(
tokens
)
tokens
=
topP
(
tokens
,
s
.
topP
)
tokens
=
topP
(
tokens
,
s
.
topP
)
tokens
=
minP
(
tokens
,
s
.
minP
)
tokens
=
minP
(
tokens
,
s
.
minP
)
...
...
sample/transforms.go
View file @
7e34f4fb
...
@@ -5,13 +5,25 @@ import (
...
@@ -5,13 +5,25 @@ import (
"slices"
"slices"
)
)
func
softmax
(
ts
[]
token
)
[]
token
{
// temperature applies scaling and softmax to the logits
func
temperature
(
ts
[]
token
,
temp
float32
)
[]
token
{
// Find max logit for numerical stability
maxLogit
:=
float32
(
math
.
Inf
(
-
1
))
for
_
,
t
:=
range
ts
{
if
t
.
value
>
maxLogit
{
maxLogit
=
t
.
value
}
}
// Apply temperature and compute exp(x - max)
temp
=
max
(
temp
,
1e-7
)
var
sum
float32
var
sum
float32
for
i
,
v
:=
range
ts
{
for
i
,
v
:=
range
ts
{
ts
[
i
]
.
value
=
float32
(
math
.
Exp
(
float64
(
v
.
value
)))
ts
[
i
]
.
value
=
float32
(
math
.
Exp
(
float64
(
(
v
.
value
-
maxLogit
)
/
temp
)))
sum
+=
ts
[
i
]
.
value
sum
+=
ts
[
i
]
.
value
}
}
// Normalize
for
i
:=
range
ts
{
for
i
:=
range
ts
{
ts
[
i
]
.
value
/=
sum
ts
[
i
]
.
value
/=
sum
}
}
...
@@ -19,27 +31,6 @@ func softmax(ts []token) []token {
...
@@ -19,27 +31,6 @@ func softmax(ts []token) []token {
return
ts
return
ts
}
}
func
temperature
(
ti
[]
token
,
t
float32
)
[]
token
{
if
t
==
1
{
return
ti
}
temp
:=
max
(
t
,
1e-7
)
maxLogit
:=
float32
(
math
.
Inf
(
-
1
))
for
_
,
token
:=
range
ti
{
if
token
.
value
>
maxLogit
{
maxLogit
=
token
.
value
}
}
// subtracting max logit to avoid under/overflow
for
i
:=
range
ti
{
ti
[
i
]
.
value
=
(
ti
[
i
]
.
value
-
maxLogit
)
/
temp
}
return
ti
}
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
//
//
// The heap is represented as an array where for any node at index i:
// The heap is represented as an array where for any node at index i:
...
@@ -145,7 +136,8 @@ func minP(ts []token, p float32) []token {
...
@@ -145,7 +136,8 @@ func minP(ts []token, p float32) []token {
}
}
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// Conting sort implementation to sort tokens by logits
// sortLogits sorts implementation to sort tokens by logits using counting sort
// counting sort is faster than built-in sort for this use case
func
sortLogits
(
tokens
[]
token
)
{
func
sortLogits
(
tokens
[]
token
)
{
if
len
(
tokens
)
<=
1
{
if
len
(
tokens
)
<=
1
{
return
return
...
...
sample/transforms_test.go
View file @
7e34f4fb
...
@@ -32,17 +32,9 @@ func compareLogits(t *testing.T, name string, want []float64, got []token) {
...
@@ -32,17 +32,9 @@ func compareLogits(t *testing.T, name string, want []float64, got []token) {
}
}
}
}
func
TestTemperature
(
t
*
testing
.
T
)
{
func
TestTemperatureAndSoftmax
(
t
*
testing
.
T
)
{
input
:=
[]
float64
{
2
,
-
1
,
4
,
-
3
,
1
,
-
2
,
0
}
input
:=
[]
float64
{
1
,
4
,
-
2
,
0
}
want
:=
[]
float64
{
-
4
,
-
10
,
0
,
-
14
,
-
6
,
-
12
,
-
8
}
// (logit - max logit) / temp
got
:=
temperature
(
toTokens
(
input
),
0.5
)
got
:=
temperature
(
toTokens
(
input
),
0.5
)
compareLogits
(
t
,
"Temperature"
,
want
,
got
)
}
func
TestSoftmax
(
t
*
testing
.
T
)
{
input
:=
[]
float64
{
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
4
}
got
:=
softmax
(
toTokens
(
input
))
// Check probabilities sum to 1
// Check probabilities sum to 1
var
sum
float32
var
sum
float32
...
@@ -53,11 +45,14 @@ func TestSoftmax(t *testing.T) {
...
@@ -53,11 +45,14 @@ func TestSoftmax(t *testing.T) {
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
}
}
// Check relative ordering is preserved
got
=
temperature
(
toTokens
(
input
),
1
)
for
i
:=
1
;
i
<
len
(
got
);
i
++
{
// Check probabilities sum to 1
if
got
[
i
]
.
value
<
got
[
i
-
1
]
.
value
{
sum
=
0.0
t
.
Errorf
(
"probability ordering not preserved at index %d"
,
i
)
for
_
,
token
:=
range
got
{
}
sum
+=
token
.
value
}
if
math
.
Abs
(
float64
(
sum
)
-
1.0
)
>
1e-6
{
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
}
}
}
}
...
@@ -84,7 +79,6 @@ func TestTopP(t *testing.T) {
...
@@ -84,7 +79,6 @@ func TestTopP(t *testing.T) {
// First apply temperature and softmax to get probabilities
// First apply temperature and softmax to get probabilities
tokens
=
temperature
(
tokens
,
1
)
tokens
=
temperature
(
tokens
,
1
)
tokens
=
softmax
(
tokens
)
sortLogits
(
tokens
)
sortLogits
(
tokens
)
// Then apply topP
// Then apply topP
...
@@ -103,7 +97,6 @@ func TestMinP(t *testing.T) {
...
@@ -103,7 +97,6 @@ func TestMinP(t *testing.T) {
// First apply temperature and softmax
// First apply temperature and softmax
tokens
=
temperature
(
tokens
,
1
)
tokens
=
temperature
(
tokens
,
1
)
tokens
=
softmax
(
tokens
)
// Then apply minP
// Then apply minP
got
:=
minP
(
tokens
,
0.2
)
got
:=
minP
(
tokens
,
0.2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment