Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
3ba91634
Commit
3ba91634
authored
Mar 12, 2025
by
ParthSareen
Committed by
Parth Sareen
Mar 12, 2025
Browse files
sample: simplify top_k=0 sorting
parent
1b7433b7
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
170 deletions
+11
-170
sample/testdata/logits.bin
sample/testdata/logits.bin
+0
-1
sample/transforms.go
sample/transforms.go
+11
-72
sample/transforms_test.go
sample/transforms_test.go
+0
-97
No files found.
sample/testdata/logits.bin
deleted
100644 → 0
View file @
1b7433b7
This diff is collapsed.
Click to expand it.
sample/transforms.go
View file @
3ba91634
...
@@ -10,7 +10,7 @@ import (
...
@@ -10,7 +10,7 @@ import (
// tokenHeap is a slice of token values.
// NOTE(review): presumably driven through container/heap (Len/Less/Swap/Push
// are defined on it and heap.Pop is called on a value of it) — confirm.
type tokenHeap []token
// tokenHeap is a slice of token values.
// NOTE(review): presumably driven through container/heap (Len/Less/Swap/Push
// are defined on it and heap.Pop is called on a value of it) — confirm.
type tokenHeap []token
func
(
h
tokenHeap
)
Len
()
int
{
return
len
(
h
)
}
func
(
h
tokenHeap
)
Len
()
int
{
return
len
(
h
)
}
func
(
h
tokenHeap
)
Less
(
i
,
j
int
)
bool
{
return
h
[
i
]
.
value
<
h
[
j
]
.
value
}
// Use < for min-heap to track largest elements
func
(
h
tokenHeap
)
Less
(
i
,
j
int
)
bool
{
return
h
[
i
]
.
value
<
h
[
j
]
.
value
}
func
(
h
tokenHeap
)
Swap
(
i
,
j
int
)
{
h
[
i
],
h
[
j
]
=
h
[
j
],
h
[
i
]
}
func
(
h
tokenHeap
)
Swap
(
i
,
j
int
)
{
h
[
i
],
h
[
j
]
=
h
[
j
],
h
[
i
]
}
func
(
h
*
tokenHeap
)
Push
(
x
any
)
{
func
(
h
*
tokenHeap
)
Push
(
x
any
)
{
...
@@ -72,7 +72,7 @@ func topK(ts []token, k int) []token {
...
@@ -72,7 +72,7 @@ func topK(ts []token, k int) []token {
}
}
// Convert heap to sorted slice in descending order
// Convert heap to sorted slice in descending order
result
:=
make
([]
token
,
k
)
result
:=
make
([]
token
,
len
(
h
)
)
for
i
:=
k
-
1
;
i
>=
0
;
i
--
{
for
i
:=
k
-
1
;
i
>=
0
;
i
--
{
result
[
i
]
=
heap
.
Pop
(
&
h
)
.
(
token
)
result
[
i
]
=
heap
.
Pop
(
&
h
)
.
(
token
)
}
}
...
@@ -126,77 +126,16 @@ func minP(ts []token, p float32) []token {
...
@@ -126,77 +126,16 @@ func minP(ts []token, p float32) []token {
return
ts
return
ts
}
}
// partialSortLogits uses quickselect to efficiently find and sort the top n tokens
// sortLogits sorts the tokens in descending order of logits
func
partialSortLogits
(
ts
[]
token
,
n
int
)
[]
token
{
if
n
>=
len
(
ts
)
{
n
=
len
(
ts
)
}
left
,
right
:=
0
,
len
(
ts
)
-
1
target
:=
n
-
1
// Quickselect algorithm to partition array around pivot
for
left
<
right
{
// Choose middle element as pivot and move it to the end
pivot
:=
left
+
(
right
-
left
)
/
2
ts
[
pivot
],
ts
[
right
]
=
ts
[
right
],
ts
[
pivot
]
// storeIndex tracks where to put next element greater than pivot
storeIndex
:=
left
pivotValue
:=
ts
[
right
]
.
value
// Partition array into elements >= pivot and < pivot
// Elements >= pivot go to the left side
for
i
:=
left
;
i
<
right
;
i
++
{
if
ts
[
i
]
.
value
>=
pivotValue
{
ts
[
storeIndex
],
ts
[
i
]
=
ts
[
i
],
ts
[
storeIndex
]
storeIndex
++
}
}
// Move pivot to its final position
ts
[
right
],
ts
[
storeIndex
]
=
ts
[
storeIndex
],
ts
[
right
]
// If pivot is at target position, we're done
// Otherwise recursively partition the half containing target
if
storeIndex
==
target
{
break
}
else
if
storeIndex
<
target
{
left
=
storeIndex
+
1
// Target is in right half
}
else
{
right
=
storeIndex
-
1
// Target is in left half
}
}
// Sort just the top n elements in descending order
slices
.
SortFunc
(
ts
[
:
n
],
func
(
a
,
b
token
)
int
{
if
a
.
value
>
b
.
value
{
return
-
1
}
if
a
.
value
<
b
.
value
{
return
1
}
return
0
})
return
ts
[
:
n
]
}
// sortLogits uses partialSortLogits to efficiently sort tokens
// It sorts approximately sqrt(len(tokens)) elements which balances
// between having enough tokens for sampling while avoiding full sort
func
sortLogits
(
ts
[]
token
)
{
func
sortLogits
(
ts
[]
token
)
{
// Use sqrt of token length as a heuristic for partial sort size
slices
.
SortFunc
(
ts
,
func
(
a
,
b
token
)
int
{
// This provides a good balance between performance and having enough tokens
n
:=
int
(
math
.
Sqrt
(
float64
(
len
(
ts
))))
+
1
// Ensure we have at least 100 tokens and at most 1000
switch
{
switch
{
case
n
<
100
:
case
a
.
value
<
b
.
value
:
n
=
100
return
1
case
n
>
1000
:
case
a
.
value
>
b
.
value
:
n
=
1000
return
-
1
default
:
return
0
}
}
})
partialSortLogits
(
ts
,
n
)
}
}
sample/transforms_test.go
View file @
3ba91634
package
sample
package
sample
import
(
import
(
"encoding/binary"
"errors"
"math"
"math"
"math/rand/v2"
"math/rand/v2"
"os"
"path/filepath"
"runtime"
"testing"
"testing"
)
)
...
@@ -130,98 +125,6 @@ func TestSortLogits(t *testing.T) {
...
@@ -130,98 +125,6 @@ func TestSortLogits(t *testing.T) {
compareLogits
(
t
,
"sortLogits"
,
want
,
tokens
)
compareLogits
(
t
,
"sortLogits"
,
want
,
tokens
)
}
}
// TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
func
TestSortLogitsWithRealData
(
t
*
testing
.
T
)
{
// This will be populated from testdata/logits.bin
// Format: 32-bit float array in binary format
logits
,
err
:=
loadTestLogits
(
t
)
if
err
!=
nil
{
t
.
Skipf
(
"Skipping real logit test: %v"
,
err
)
return
}
tokens
:=
toTokens
(
logits
)
sortLogits
(
tokens
)
// Calculate n for verification
n
:=
int
(
math
.
Sqrt
(
float64
(
len
(
tokens
))))
+
1
if
n
>
1000
{
n
=
1000
}
else
if
n
<
100
{
n
=
100
}
t
.
Logf
(
"Testing with %d tokens, partial sorting top %d"
,
len
(
tokens
),
n
)
// Only verify the top n elements are sorted (which is what we guarantee)
// This is much faster than checking the entire array
topN
:=
tokens
[
:
n
]
for
i
:=
1
;
i
<
len
(
topN
);
i
++
{
if
topN
[
i
]
.
value
>
topN
[
i
-
1
]
.
value
{
t
.
Fatalf
(
"top %d tokens not properly sorted at index %d: %.15f > %.15f"
,
n
,
i
,
topN
[
i
]
.
value
,
topN
[
i
-
1
]
.
value
)
}
}
// Verify we didn't lose any high value tokens by checking that
// all tokens after position n are <= the nth token
// Do this in chunks to avoid timeouts on large arrays
nthValue
:=
tokens
[
n
-
1
]
.
value
const
chunkSize
=
1000
for
start
:=
n
;
start
<
len
(
tokens
);
start
+=
chunkSize
{
end
:=
min
(
start
+
chunkSize
,
len
(
tokens
))
for
i
:=
start
;
i
<
end
;
i
++
{
if
tokens
[
i
]
.
value
>
nthValue
{
t
.
Fatalf
(
"found higher value token after position %d: tokens[%d].value = %.15f > %.15f"
,
n
,
i
,
tokens
[
i
]
.
value
,
nthValue
)
}
}
}
}
// loadTestLogits loads logit test data from testdata/logits.bin, resolved
// relative to the directory of this source file.
//
// The file is decoded as a flat sequence of little-endian 32-bit floats. An
// error is returned when the file cannot be read, when its size is not a
// multiple of 4 bytes, or when it is empty.
func loadTestLogits(t *testing.T) ([]float32, error) {
	t.Helper()

	_, currFile, _, ok := runtime.Caller(0)
	if !ok {
		return nil, errors.New("could not determine test file path")
	}

	testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")

	// Read the whole file in one go instead of issuing one small read per
	// float against the unbuffered file handle.
	data, err := os.ReadFile(testDataPath)
	if err != nil {
		return nil, err
	}

	const floatSize = 4 // each float32 is 4 bytes
	if len(data)%floatSize != 0 {
		return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
	}
	if len(data) == 0 {
		return nil, errors.New("logits.bin is empty")
	}

	logits := make([]float32, len(data)/floatSize)
	for i := range logits {
		bits := binary.LittleEndian.Uint32(data[i*floatSize:])
		logits[i] = math.Float32frombits(bits)
	}

	return logits, nil
}
func
BenchmarkTransforms
(
b
*
testing
.
B
)
{
func
BenchmarkTransforms
(
b
*
testing
.
B
)
{
// Generate random logits
// Generate random logits
tokens
:=
make
([]
token
,
1
<<
16
)
tokens
:=
make
([]
token
,
1
<<
16
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment