Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
108fe021
Unverified
Commit
108fe021
authored
Mar 17, 2025
by
Parth Sareen
Committed by
GitHub
Mar 17, 2025
Browse files
sample: make mutations in transforms explicit (#9743)
* updated minP to use early exit making use of sorted tokens
parent
50b59620
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
109 additions
and
71 deletions
+109
-71
sample/samplers.go
sample/samplers.go
+3
-2
sample/transforms.go
sample/transforms.go
+11
-26
sample/transforms_test.go
sample/transforms_test.go
+95
-43
No files found.
sample/samplers.go
View file @
108fe021
...
...
@@ -87,8 +87,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
// topK also sorts the tokens in descending order of logits
tokens
=
topK
(
tokens
,
s
.
topK
)
tokens
=
temperature
(
tokens
,
s
.
temperature
)
tokens
=
softmax
(
tokens
)
// scale and normalize the tokens in place
temperature
(
tokens
,
s
.
temperature
)
softmax
(
tokens
)
tokens
=
topP
(
tokens
,
s
.
topP
)
tokens
=
minP
(
tokens
,
s
.
minP
)
...
...
sample/transforms.go
View file @
108fe021
...
...
@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
}
// temperature applies scaling to the logits
func
temperature
(
ts
[]
token
,
temp
float32
)
[]
token
{
func
temperature
(
ts
[]
token
,
temp
float32
)
{
// Ensure temperature clipping near 0 to avoid numerical instability
temp
=
max
(
temp
,
1e-7
)
for
i
:=
range
ts
{
ts
[
i
]
.
value
=
ts
[
i
]
.
value
/
temp
}
return
ts
}
// softmax applies normalization to the logits
func
softmax
(
ts
[]
token
)
[]
token
{
func
softmax
(
ts
[]
token
)
{
// Find max logit for numerical stability
maxLogit
:=
float32
(
math
.
Inf
(
-
1
))
for
_
,
t
:=
range
ts
{
...
...
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
for
i
:=
range
ts
{
ts
[
i
]
.
value
/=
sum
}
return
ts
}
// topK limits the number of tokens considered to the k highest logits
...
...
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
}
// topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func
topP
(
ts
[]
token
,
p
float32
)
[]
token
{
if
p
==
1.0
{
return
ts
...
...
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
for
i
,
t
:=
range
ts
{
sum
+=
t
.
value
if
sum
>
float32
(
p
)
{
ts
=
ts
[
:
i
+
1
]
return
ts
return
ts
[
:
i
+
1
]
}
}
return
ts
}
// minP limits tokens to those with cumulative probability p
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
func
minP
(
ts
[]
token
,
p
float32
)
[]
token
{
if
p
==
1.0
{
return
ts
}
maxProb
:=
float32
(
math
.
Inf
(
-
1
))
for
_
,
token
:=
range
ts
{
if
token
.
value
>
maxProb
{
maxProb
=
token
.
value
}
}
maxProb
:=
ts
[
0
]
.
value
threshold
:=
maxProb
*
float32
(
p
)
threshold
:=
maxProb
*
p
// Filter tokens in-place
validTokens
:=
ts
[
:
0
]
for
i
,
token
:=
range
ts
{
if
token
.
value
>=
threshold
{
validTokens
=
append
(
validTokens
,
ts
[
i
])
for
i
,
t
:=
range
ts
{
if
t
.
value
<
threshold
{
return
ts
[
:
i
]
}
}
ts
=
validTokens
return
ts
}
sample/transforms_test.go
View file @
108fe021
...
...
@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
func
TestTemperature
(
t
*
testing
.
T
)
{
input
:=
[]
float32
{
1.0
,
4.0
,
-
2.0
,
0.0
}
got
:=
temperature
(
toTokens
(
input
),
0.5
)
tokens
:=
toTokens
(
input
)
temperature
(
tokens
,
0.5
)
want
:=
[]
float32
{
2.0
,
8.0
,
-
4.0
,
0.0
}
compareLogits
(
t
,
"temperature(0.5)"
,
want
,
got
)
compareLogits
(
t
,
"temperature(0.5)"
,
want
,
tokens
)
got
=
temperature
(
toTokens
(
input
),
1.0
)
input
=
[]
float32
{
1.0
,
4.0
,
-
2.0
,
0.0
}
tokens
=
toTokens
(
input
)
temperature
(
tokens
,
1.0
)
want
=
[]
float32
{
1.0
,
4.0
,
-
2.0
,
0.0
}
compareLogits
(
t
,
"temperature(1)"
,
want
,
got
)
compareLogits
(
t
,
"temperature(1)"
,
want
,
tokens
)
got
=
temperature
(
toTokens
(
input
),
0.0
)
input
=
[]
float32
{
1.0
,
4.0
,
-
2.0
,
0.0
}
tokens
=
toTokens
(
input
)
temperature
(
tokens
,
0.0
)
want
=
[]
float32
{
1e7
,
4e7
,
-
2e7
,
0.0
}
compareLogits
(
t
,
"temperature(0)"
,
want
,
got
)
compareLogits
(
t
,
"temperature(0)"
,
want
,
tokens
)
}
func
TestSoftmax
(
t
*
testing
.
T
)
{
...
...
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
softmax
(
toTokens
(
tt
.
input
))
tokens
:=
toTokens
(
tt
.
input
)
softmax
(
tokens
)
if
tt
.
expected
!=
nil
{
compareLogits
(
t
,
tt
.
name
,
tt
.
expected
,
got
)
compareLogits
(
t
,
tt
.
name
,
tt
.
expected
,
tokens
)
return
}
// Check probabilities sum to 1
var
sum
float32
for
_
,
token
:=
range
got
{
for
_
,
token
:=
range
tokens
{
sum
+=
token
.
value
if
token
.
value
<
0
||
token
.
value
>
1
{
t
.
Errorf
(
"probability out of range [0,1]: got %f"
,
token
.
value
)
...
...
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {
func
TestTopK
(
t
*
testing
.
T
)
{
input
:=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
// Test k=5
got
:=
topK
(
toTokens
(
input
),
5
)
if
len
(
got
)
!=
5
{
t
.
Errorf
(
"topK(5): wrong length: want 5, got %d"
,
len
(
got
))
tokens
:=
toTokens
(
input
)
tokens
=
topK
(
tokens
,
5
)
if
len
(
tokens
)
!=
5
{
t
.
Errorf
(
"topK(5): wrong length: want 5, got %d"
,
len
(
tokens
))
}
// Should keep highest 5 values in descending order
want
:=
[]
float32
{
0.27755088
,
0.20409796
,
0.15720603
,
0.08582123
,
0.045046154
}
compareLogits
(
t
,
"topK(3)"
,
want
,
got
)
compareLogits
(
t
,
"topK(3)"
,
want
,
tokens
)
got
=
topK
(
toTokens
(
input
),
20
)
if
len
(
got
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(20): wrong length: want %d, got %d"
,
len
(
input
),
len
(
got
))
tokens
=
toTokens
(
input
)
tokens
=
topK
(
tokens
,
20
)
if
len
(
tokens
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(20): wrong length: want %d, got %d"
,
len
(
input
),
len
(
tokens
))
}
// Test k=-1
input
=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
want
=
[]
float32
{
0.27755088
,
0.20409796
,
0.15720603
,
0.08582123
,
0.045046154
,
0.043722924
,
0.036774673
,
0.026986899
,
0.01681367
,
0.0046718004
,
0.00412893
,
0.0030491839
}
got
=
topK
(
toTokens
(
input
),
-
1
)
if
len
(
got
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(-1): wrong length: want %d, got %d"
,
len
(
input
),
len
(
got
))
tokens
=
toTokens
(
input
)
tokens
=
topK
(
tokens
,
-
1
)
if
len
(
tokens
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(-1): wrong length: want %d, got %d"
,
len
(
input
),
len
(
tokens
))
}
compareLogits
(
t
,
"topK(-1)"
,
want
,
got
)
compareLogits
(
t
,
"topK(-1)"
,
want
,
tokens
)
// Test k=0
input
=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
want
=
[]
float32
{
0.27755088
,
0.20409796
,
0.15720603
,
0.08582123
,
0.045046154
,
0.043722924
,
0.036774673
,
0.026986899
,
0.01681367
,
0.0046718004
,
0.00412893
,
0.0030491839
}
got
=
topK
(
toTokens
(
input
),
0
)
if
len
(
got
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(-1): wrong length: want %d, got %d"
,
len
(
input
),
len
(
got
))
tokens
=
toTokens
(
input
)
tokens
=
topK
(
tokens
,
0
)
if
len
(
tokens
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(-1): wrong length: want %d, got %d"
,
len
(
input
),
len
(
tokens
))
}
compareLogits
(
t
,
"topK(-1)"
,
want
,
tokens
)
input
=
[]
float32
{
-
1e7
,
-
2e7
,
-
3e7
,
-
4e7
}
tokens
=
toTokens
(
input
)
tokens
=
topK
(
tokens
,
1
)
if
len
(
tokens
)
<
1
{
t
.
Error
(
"topK should keep at least one token"
)
}
compareLogits
(
t
,
"topK(-1)"
,
want
,
got
)
}
func
TestTopP
(
t
*
testing
.
T
)
{
...
...
@@ -153,16 +165,25 @@ func TestTopP(t *testing.T) {
tokens
:=
toTokens
(
input
)
// First apply softmax to get probabilities, then topK to sort them in descending order
tokens
=
softmax
(
tokens
)
softmax
(
tokens
)
tokens
=
topK
(
tokens
,
20
)
// Then apply topP
got
:
=
topP
(
tokens
,
0.95
)
tokens
=
topP
(
tokens
,
0.95
)
// Should keep tokens until cumsum > 0.95
if
len
(
got
)
>
3
{
t
.
Errorf
(
"topP(0.95): kept too many tokens: got %d"
,
len
(
got
))
t
.
Logf
(
"got: %v"
,
got
)
if
len
(
tokens
)
>
3
{
t
.
Errorf
(
"topP(0.95): kept too many tokens: got %d"
,
len
(
tokens
))
t
.
Logf
(
"got: %v"
,
tokens
)
}
// Test edge case - ensure at least one token remains
input
=
[]
float32
{
-
1e6
,
-
1e6
,
-
1e6
}
// All logits are equal, so softmax yields a uniform distribution
tokens
=
toTokens
(
input
)
softmax
(
tokens
)
tokens
=
topP
(
tokens
,
0.0
)
// Very small p
if
len
(
tokens
)
<
1
{
t
.
Error
(
"topP should keep at least one token"
)
}
}
...
...
@@ -171,14 +192,45 @@ func TestMinP(t *testing.T) {
tokens
:=
toTokens
(
input
)
// First apply topK and softmax to get sorted probabilities
tokens
=
softmax
(
tokens
)
tokens
=
topK
(
tokens
,
20
)
softmax
(
tokens
)
// Then apply minP
got
:=
minP
(
tokens
,
0.2
)
tokens
=
minP
(
tokens
,
1.0
)
if
len
(
tokens
)
!=
1
{
t
.
Errorf
(
"minP(1.0): should keep all tokens, got %d, want %d"
,
len
(
tokens
),
len
(
tokens
))
}
// Test with normal p value
tokens
=
toTokens
(
input
)
// Reset tokens
tokens
=
topK
(
tokens
,
20
)
softmax
(
tokens
)
tokens
=
minP
(
tokens
,
0.2
)
// Should keep tokens with prob >= 0.2 * max_prob
if
len
(
got
)
>
3
{
t
.
Errorf
(
"minP(0.2): kept too many tokens: got %d"
,
len
(
got
))
if
len
(
tokens
)
>
3
{
t
.
Errorf
(
"minP(0.2): kept too many tokens: got %d"
,
len
(
tokens
))
t
.
Logf
(
"got: %v"
,
tokens
)
}
// Test with zero p value
tokens
=
toTokens
(
input
)
// Reset tokens
tokens
=
topK
(
tokens
,
20
)
softmax
(
tokens
)
tokens
=
minP
(
tokens
,
0.0
)
// Should keep all tokens: with p = 0 the threshold is 0 and every probability passes
if
len
(
tokens
)
!=
len
(
input
)
{
t
.
Errorf
(
"minP(0.0): should keep only one token, got %d"
,
len
(
tokens
))
t
.
Logf
(
"got: %v"
,
tokens
)
}
input
=
[]
float32
{
1e-10
,
1e-10
,
1e-10
}
tokens
=
toTokens
(
input
)
softmax
(
tokens
)
tokens
=
minP
(
tokens
,
1.0
)
if
len
(
tokens
)
<
1
{
t
.
Error
(
"minP should keep at least one token even with extreme probabilities"
)
}
}
...
...
@@ -231,7 +283,7 @@ func BenchmarkTransforms(b *testing.B) {
b
.
ResetTimer
()
for
b
.
Loop
()
{
copy
(
tokensCopy
,
tokens
)
topK
(
tokensCopy
,
10
)
tokens
=
topK
(
tokensCopy
,
10
)
}
})
...
...
@@ -239,7 +291,7 @@ func BenchmarkTransforms(b *testing.B) {
b
.
ResetTimer
()
for
b
.
Loop
()
{
copy
(
tokensCopy
,
tokens
)
topP
(
tokensCopy
,
0.9
)
tokens
=
topP
(
tokensCopy
,
0.9
)
}
})
...
...
@@ -247,7 +299,7 @@ func BenchmarkTransforms(b *testing.B) {
b
.
ResetTimer
()
for
b
.
Loop
()
{
copy
(
tokensCopy
,
tokens
)
minP
(
tokensCopy
,
0.2
)
tokens
=
minP
(
tokensCopy
,
0.2
)
}
})
...
...
@@ -255,7 +307,7 @@ func BenchmarkTransforms(b *testing.B) {
b
.
ResetTimer
()
for
b
.
Loop
()
{
copy
(
tokensCopy
,
tokens
)
topK
(
tokensCopy
,
200000
)
tokens
=
topK
(
tokensCopy
,
200000
)
}
})
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment