Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
42a14f7f
Unverified
Commit
42a14f7f
authored
Mar 20, 2025
by
Parth Sareen
Committed by
GitHub
Mar 20, 2025
Browse files
sample: add error handling for empty logits (#9740)
parent
f8c3dbe5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
97 additions
and
29 deletions
+97
-29
sample/samplers.go
sample/samplers.go
+7
-7
sample/samplers_test.go
sample/samplers_test.go
+24
-0
sample/transforms_test.go
sample/transforms_test.go
+66
-22
No files found.
sample/samplers.go
View file @
42a14f7f
...
...
@@ -26,6 +26,10 @@ type Sampler struct {
}
func
(
s
*
Sampler
)
Sample
(
logits
[]
float32
)
(
int32
,
error
)
{
if
len
(
logits
)
==
0
{
return
-
1
,
errors
.
New
(
"sample: no logits provided to sample"
)
}
tokens
:=
make
([]
token
,
len
(
logits
))
for
i
:=
range
logits
{
tokens
[
i
]
.
id
=
int32
(
i
)
...
...
@@ -94,13 +98,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
tokens
=
topP
(
tokens
,
s
.
topP
)
tokens
=
minP
(
tokens
,
s
.
minP
)
// TODO: this should fall back to greedy sampling
// or topP, topK values etc should be such that
// there are always tokens to sample from
if
len
(
tokens
)
==
0
{
return
token
{},
errors
.
New
(
"no tokens to sample from"
)
}
var
r
float32
if
s
.
rng
!=
nil
{
r
=
s
.
rng
.
Float32
()
...
...
@@ -123,6 +120,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return
1
})
if
math
.
IsNaN
(
float64
(
sum
))
{
return
token
{},
errors
.
New
(
"sample: logits sum to NaN, check model output"
)
}
return
tokens
[
idx
],
nil
}
...
...
sample/samplers_test.go
View file @
42a14f7f
package
sample
import
(
"math"
"math/rand/v2"
"testing"
)
...
...
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
if
want
!=
got
{
t
.
Errorf
(
"index mismatch: want %d, got %d"
,
want
,
got
)
}
// Test very high p
logits
=
[]
float32
{
1.0
,
0.9999999999999999
,
0.5
,
0.1
}
// Use extremely small topP to filter out all tokens
sampler
=
NewSampler
(
1.0
,
0
,
1e-10
,
0
,
0
,
nil
)
got
,
err
=
sampler
.
Sample
(
logits
)
if
err
!=
nil
{
t
.
Error
(
err
)
return
}
// Should get the token with the highest logit
want
=
int32
(
0
)
if
want
!=
got
{
t
.
Errorf
(
"index mismatch: want %d, got %d"
,
want
,
got
)
}
logits
=
[]
float32
{
float32
(
math
.
NaN
()),
float32
(
math
.
NaN
()),
float32
(
math
.
NaN
())}
sampler
=
NewSampler
(
1
,
0
,
0.95
,
0.05
,
0
,
nil
)
got
,
err
=
sampler
.
Sample
(
logits
)
if
err
==
nil
{
t
.
Errorf
(
"expected error, got %d"
,
got
)
return
}
}
func
BenchmarkSample
(
b
*
testing
.
B
)
{
...
...
sample/transforms_test.go
View file @
42a14f7f
...
...
@@ -168,27 +168,53 @@ func TestTopP(t *testing.T) {
softmax
(
tokens
)
tokens
=
topK
(
tokens
,
20
)
// T
hen apply topP
tokens
=
topP
(
tokens
,
0.95
)
// T
est with very high p value
got
:
=
topP
(
tokens
,
1.0
)
// Should keep tokens until cumsum > 0.95
if
len
(
tokens
)
>
3
{
// Should keep all tokens since p is 1
if
len
(
got
)
!=
len
(
input
)
{
t
.
Errorf
(
"topP(1.0): should keep all tokens, got %d, want %d"
,
len
(
got
),
len
(
input
))
}
// Test with normal p value
got
=
topP
(
tokens
,
0.95
)
if
len
(
got
)
>
3
{
t
.
Errorf
(
"topP(0.95): kept too many tokens: got %d"
,
len
(
tokens
))
t
.
Logf
(
"got: %v"
,
tokens
)
t
.
Logf
(
"got: %v"
,
got
)
}
// Test edge case - ensure at least one token remains
input
=
[]
float32
{
-
1e6
,
-
1e6
,
-
1e
6
}
// One dominant token
input
=
[]
float32
{
-
1e6
,
-
1e6
,
-
1e
7
}
tokens
=
toTokens
(
input
)
tokens
=
topK
(
tokens
,
20
)
softmax
(
tokens
)
tokens
=
topP
(
tokens
,
0.0
)
// Very small p
if
len
(
tokens
)
<
1
{
got
=
topP
(
tokens
,
0.0
)
if
len
(
got
)
<
1
{
t
.
Error
(
"topP should keep at least one token"
)
}
// Test with zero p value
got
=
topP
(
tokens
,
0.0
)
// Should keep only the highest probability token
if
len
(
got
)
!=
1
{
t
.
Errorf
(
"topP(0.0): should keep only one token, got %d"
,
len
(
got
))
t
.
Logf
(
"got: %v"
,
got
)
}
tokens
=
toTokens
(
input
)
tokens
=
topK
(
tokens
,
20
)
softmax
(
tokens
)
got
=
topP
(
tokens
,
1e-10
)
if
len
(
got
)
==
0
{
t
.
Errorf
(
"topP(1e-10): should keep at least one token, got %d"
,
len
(
got
))
t
.
Logf
(
"got: %v"
,
got
)
}
}
func
TestMinP
(
t
*
testing
.
T
)
{
input
:=
[]
float32
{
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
4
,
3
}
input
:=
[]
float32
{
-
2
,
0
,
-
1
,
-
3
,
2
,
1
,
4
,
3
}
tokens
:=
toTokens
(
input
)
// First apply temperature and softmax
...
...
@@ -225,30 +251,48 @@ func TestMinP(t *testing.T) {
t
.
Logf
(
"got: %v"
,
tokens
)
}
// Test with single token
tokens
=
toTokens
(
input
[
:
1
])
tokens
=
topK
(
tokens
,
20
)
softmax
(
tokens
)
tokens
=
minP
(
tokens
,
0.1
)
// Should keep only the highest probability token
if
len
(
tokens
)
!=
1
{
t
.
Errorf
(
"minP(0.1): should return single token, got %d"
,
len
(
tokens
))
t
.
Logf
(
"got: %v"
,
tokens
)
}
input
=
[]
float32
{
1e-10
,
1e-10
,
1e-10
}
tokens
=
toTokens
(
input
)
softmax
(
tokens
)
tokens
=
minP
(
tokens
,
1.0
)
if
len
(
tokens
)
<
1
{
t
.
Error
(
"minP should keep at least one token even with extreme probabilities"
)
}
}
got
:=
minP
(
tokens
,
1.0
)
func
TestSortLogits
(
t
*
testing
.
T
)
{
input
:=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
tokens
:=
toTokens
(
input
)
if
len
(
got
)
!=
1
{
t
.
Errorf
(
"minP(1.0): should keep all tokens, got %d, want %d"
,
len
(
got
),
len
(
tokens
))
}
tokens
=
topK
(
tokens
,
20
)
// Test with normal p value
got
=
minP
(
tokens
,
0.2
)
for
i
:=
1
;
i
<
len
(
tokens
);
i
++
{
if
tokens
[
i
]
.
value
>
tokens
[
i
-
1
]
.
value
{
t
.
Errorf
(
"
sortLogits: tokens not sorted in descending order at index %d: %f > %f"
,
i
,
tokens
[
i
]
.
value
,
tokens
[
i
-
1
]
.
value
)
// Should keep tokens with prob >= 0.2 * max_prob
if
len
(
got
)
>
3
{
t
.
Errorf
(
"
minP(0.2): kept too many tokens: got %d"
,
len
(
got
))
t
.
Logf
(
"got: %v"
,
got
)
}
}
want
:=
[]
float32
{
0.27755088
,
0.20409796
,
0.15720603
,
0.08582123
,
0.045046154
,
0.043722924
,
0.036774673
,
0.026986899
,
0.01681367
,
0.0046718004
,
0.00412893
,
0.0030491839
}
compareLogits
(
t
,
"sortLogits"
,
want
,
tokens
)
// Test with zero p value
got
=
minP
(
tokens
,
0.0
)
// Should keep only the highest probability token
if
len
(
got
)
!=
len
(
tokens
)
{
t
.
Errorf
(
"minP(0.0): should keep only one token, got %d"
,
len
(
got
))
t
.
Logf
(
"got: %v"
,
got
)
}
}
}
func
BenchmarkTransforms
(
b
*
testing
.
B
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment