Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
5c0b6639
Unverified
Commit
5c0b6639
authored
Mar 13, 2025
by
Parth Sareen
Committed by
GitHub
Mar 13, 2025
Browse files
sample: separate softmax and temperature transforms (#9732)
parent
4aeb67ef
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
98 additions
and
25 deletions
+98
-25
sample/samplers.go
sample/samplers.go
+1
-1
sample/transforms.go
sample/transforms.go
+14
-5
sample/transforms_test.go
sample/transforms_test.go
+83
-19
No files found.
sample/samplers.go
View file @
5c0b6639
...
@@ -87,8 +87,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
...
@@ -87,8 +87,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
// topK also sorts the tokens in descending order of logits
// topK also sorts the tokens in descending order of logits
tokens
=
topK
(
tokens
,
s
.
topK
)
tokens
=
topK
(
tokens
,
s
.
topK
)
// token logit values are updated to probabilities
tokens
=
temperature
(
tokens
,
s
.
temperature
)
tokens
=
temperature
(
tokens
,
s
.
temperature
)
tokens
=
softmax
(
tokens
)
tokens
=
topP
(
tokens
,
s
.
topP
)
tokens
=
topP
(
tokens
,
s
.
topP
)
tokens
=
minP
(
tokens
,
s
.
minP
)
tokens
=
minP
(
tokens
,
s
.
minP
)
...
...
sample/transforms.go
View file @
5c0b6639
...
@@ -25,8 +25,18 @@ func (h *tokenHeap) Pop() any {
...
@@ -25,8 +25,18 @@ func (h *tokenHeap) Pop() any {
return
x
return
x
}
}
// temperature applies scaling
and softmax
to the logits
// temperature applies scaling to the logits
func
temperature
(
ts
[]
token
,
temp
float32
)
[]
token
{
func
temperature
(
ts
[]
token
,
temp
float32
)
[]
token
{
// Ensure temperature clipping near 0 to avoid numerical instability
temp
=
max
(
temp
,
1e-7
)
for
i
:=
range
ts
{
ts
[
i
]
.
value
=
ts
[
i
]
.
value
/
temp
}
return
ts
}
// softmax applies normalization to the logits
func
softmax
(
ts
[]
token
)
[]
token
{
// Find max logit for numerical stability
// Find max logit for numerical stability
maxLogit
:=
float32
(
math
.
Inf
(
-
1
))
maxLogit
:=
float32
(
math
.
Inf
(
-
1
))
for
_
,
t
:=
range
ts
{
for
_
,
t
:=
range
ts
{
...
@@ -35,15 +45,14 @@ func temperature(ts []token, temp float32) []token {
...
@@ -35,15 +45,14 @@ func temperature(ts []token, temp float32) []token {
}
}
}
}
// Apply temperature and compute exp(x - max)
// Compute exp(x - max)
temp
=
max
(
temp
,
1e-7
)
var
sum
float32
var
sum
float32
for
i
,
v
:=
range
ts
{
for
i
,
v
:=
range
ts
{
ts
[
i
]
.
value
=
float32
(
math
.
Exp
(
float64
(
(
v
.
value
-
maxLogit
)
/
temp
)
))
ts
[
i
]
.
value
=
float32
(
math
.
Exp
(
float64
(
v
.
value
-
maxLogit
)))
sum
+=
ts
[
i
]
.
value
sum
+=
ts
[
i
]
.
value
}
}
//
Normalize
//
exp(x - max) / sum(exp(x - max))
for
i
:=
range
ts
{
for
i
:=
range
ts
{
ts
[
i
]
.
value
/=
sum
ts
[
i
]
.
value
/=
sum
}
}
...
...
sample/transforms_test.go
View file @
5c0b6639
...
@@ -32,27 +32,83 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
...
@@ -32,27 +32,83 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
}
}
}
}
func
TestTemperature
AndSoftmax
(
t
*
testing
.
T
)
{
func
TestTemperature
(
t
*
testing
.
T
)
{
input
:=
[]
float32
{
1
,
4
,
-
2
,
0
}
input
:=
[]
float32
{
1
.0
,
4
.0
,
-
2
.0
,
0.
0
}
got
:=
temperature
(
toTokens
(
input
),
0.5
)
got
:=
temperature
(
toTokens
(
input
),
0.5
)
want
:=
[]
float32
{
2.0
,
8.0
,
-
4.0
,
0.0
}
compareLogits
(
t
,
"temperature(0.5)"
,
want
,
got
)
// Check probabilities sum to 1
got
=
temperature
(
toTokens
(
input
),
1.0
)
var
sum
float32
want
=
[]
float32
{
1.0
,
4.0
,
-
2.0
,
0.0
}
for
_
,
token
:=
range
got
{
compareLogits
(
t
,
"temperature(1)"
,
want
,
got
)
sum
+=
token
.
value
}
got
=
temperature
(
toTokens
(
input
),
0.0
)
if
math
.
Abs
(
float64
(
sum
-
1.0
))
>
1e-6
{
want
=
[]
float32
{
1e7
,
4e7
,
-
2e7
,
0.0
}
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
compareLogits
(
t
,
"temperature(0)"
,
want
,
got
)
}
}
got
=
temperature
(
toTokens
(
input
),
1
)
func
TestSoftmax
(
t
*
testing
.
T
)
{
// Check probabilities sum to 1
tests
:=
[]
struct
{
sum
=
0.0
name
string
for
_
,
token
:=
range
got
{
input
[]
float32
sum
+=
token
.
value
expected
[]
float32
}{
{
name
:
"correctness softmax"
,
input
:
[]
float32
{
1
,
-
2
,
3
,
0
},
expected
:
[]
float32
{
0.113550
,
0.005653
,
0.839024
,
0.041773
},
},
{
name
:
"normal distribution"
,
input
:
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
},
},
{
name
:
"single value"
,
input
:
[]
float32
{
1.0
},
},
{
name
:
"identical values"
,
input
:
[]
float32
{
0.9
,
0.9
,
0.9
},
},
{
name
:
"large values"
,
input
:
[]
float32
{
1000.0
,
2000.0
,
3000.0
},
},
{
name
:
"small values"
,
input
:
[]
float32
{
1e-6
,
2e-6
,
3e-6
},
},
{
name
:
"negative values"
,
input
:
[]
float32
{
-
1.0
,
-
2.0
,
-
3.0
},
},
{
name
:
"mixed values"
,
input
:
[]
float32
{
-
100.0
,
0.0
,
100.0
},
},
}
}
if
math
.
Abs
(
float64
(
sum
-
1.0
))
>
1e-6
{
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
softmax
(
toTokens
(
tt
.
input
))
if
tt
.
expected
!=
nil
{
compareLogits
(
t
,
tt
.
name
,
tt
.
expected
,
got
)
return
}
// Check probabilities sum to 1
var
sum
float32
for
_
,
token
:=
range
got
{
sum
+=
token
.
value
if
token
.
value
<
0
||
token
.
value
>
1
{
t
.
Errorf
(
"probability out of range [0,1]: got %f"
,
token
.
value
)
}
}
if
math
.
Abs
(
float64
(
sum
-
1.0
))
>
1e-6
{
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
}
})
}
}
}
}
...
@@ -97,7 +153,7 @@ func TestTopP(t *testing.T) {
...
@@ -97,7 +153,7 @@ func TestTopP(t *testing.T) {
tokens
:=
toTokens
(
input
)
tokens
:=
toTokens
(
input
)
// First apply temperature and softmax to get probabilities
// First apply temperature and softmax to get probabilities
tokens
=
temperature
(
tokens
,
1
)
tokens
=
softmax
(
tokens
)
tokens
=
topK
(
tokens
,
20
)
tokens
=
topK
(
tokens
,
20
)
// Then apply topP
// Then apply topP
...
@@ -115,7 +171,7 @@ func TestMinP(t *testing.T) {
...
@@ -115,7 +171,7 @@ func TestMinP(t *testing.T) {
tokens
:=
toTokens
(
input
)
tokens
:=
toTokens
(
input
)
// First apply temperature and softmax
// First apply temperature and softmax
tokens
=
temperature
(
tokens
,
1
)
tokens
=
softmax
(
tokens
)
// Then apply minP
// Then apply minP
got
:=
minP
(
tokens
,
0.2
)
got
:=
minP
(
tokens
,
0.2
)
...
@@ -163,6 +219,14 @@ func BenchmarkTransforms(b *testing.B) {
...
@@ -163,6 +219,14 @@ func BenchmarkTransforms(b *testing.B) {
}
}
})
})
b
.
Run
(
"Softmax"
,
func
(
b
*
testing
.
B
)
{
b
.
ResetTimer
()
for
b
.
Loop
()
{
copy
(
tokensCopy
,
tokens
)
softmax
(
tokensCopy
)
}
})
b
.
Run
(
"TopK"
,
func
(
b
*
testing
.
B
)
{
b
.
Run
(
"TopK"
,
func
(
b
*
testing
.
B
)
{
b
.
ResetTimer
()
b
.
ResetTimer
()
for
b
.
Loop
()
{
for
b
.
Loop
()
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment