Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
1b7433b7
Commit
1b7433b7
authored
Mar 12, 2025
by
ParthSareen
Committed by
Parth Sareen
Mar 12, 2025
Browse files
sample: use container/heap for top_k
parent
a70820da
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
217 additions
and
111 deletions
+217
-111
sample/testdata/logits.bin
sample/testdata/logits.bin
+1
-0
sample/transforms.go
sample/transforms.go
+97
-90
sample/transforms_test.go
sample/transforms_test.go
+119
-21
No files found.
sample/testdata/logits.bin
0 → 100644
View file @
1b7433b7
This diff is collapsed.
Click to expand it.
sample/transforms.go
View file @
1b7433b7
package
sample
package
sample
import
(
import
(
"container/heap"
"math"
"math"
"slices"
"slices"
)
)
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
type
tokenHeap
[]
token
func
(
h
tokenHeap
)
Len
()
int
{
return
len
(
h
)
}
func
(
h
tokenHeap
)
Less
(
i
,
j
int
)
bool
{
return
h
[
i
]
.
value
<
h
[
j
]
.
value
}
// Use < for min-heap to track largest elements
func
(
h
tokenHeap
)
Swap
(
i
,
j
int
)
{
h
[
i
],
h
[
j
]
=
h
[
j
],
h
[
i
]
}
func
(
h
*
tokenHeap
)
Push
(
x
any
)
{
*
h
=
append
(
*
h
,
x
.
(
token
))
}
func
(
h
*
tokenHeap
)
Pop
()
any
{
old
:=
*
h
n
:=
len
(
old
)
x
:=
old
[
n
-
1
]
*
h
=
old
[
0
:
n
-
1
]
return
x
}
// temperature applies scaling and softmax to the logits
// temperature applies scaling and softmax to the logits
func
temperature
(
ts
[]
token
,
temp
float32
)
[]
token
{
func
temperature
(
ts
[]
token
,
temp
float32
)
[]
token
{
// Find max logit for numerical stability
// Find max logit for numerical stability
...
@@ -31,62 +51,33 @@ func temperature(ts []token, temp float32) []token {
...
@@ -31,62 +51,33 @@ func temperature(ts []token, temp float32) []token {
return
ts
return
ts
}
}
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
//
// The heap is represented as an array where for any node at index i:
// - Left child is at index 2i + 1
// - Right child is at index 2i + 2
// - Parent is at index (i-1)/2
//
// The function compares a node with its children and:
// 1. Finds the smallest value between the node and its children
// 2. If the node is not the smallest, swaps it with its smallest child
// 3. Continues this process down the affected path until the min-heap property is restored
func
siftDown
(
data
[]
token
,
start
,
end
int
)
{
root
:=
start
for
{
child
:=
2
*
root
+
1
if
child
>=
end
{
break
}
// Find smaller child (we want min heap)
if
child
+
1
<
end
&&
data
[
child
+
1
]
.
value
<
data
[
child
]
.
value
{
child
++
}
// Exit if root is already smaller than children
if
data
[
root
]
.
value
<=
data
[
child
]
.
value
{
break
}
// Swap with smaller child and continue
data
[
root
],
data
[
child
]
=
data
[
child
],
data
[
root
]
root
=
child
}
}
// topK limits the number of tokens considered to the k highest logits
// topK limits the number of tokens considered to the k highest logits
func
topK
(
ts
[]
token
,
k
int
)
[]
token
{
func
topK
(
ts
[]
token
,
k
int
)
[]
token
{
if
k
>=
len
(
ts
)
{
if
k
>=
len
(
ts
)
{
sortLogits
(
ts
)
return
ts
return
ts
}
}
// Heapify + siftDown - O(nlog(k))
// Build min-heap of first k elements
heap
:=
ts
[
:
k
]
for
i
:=
k
/
2
-
1
;
i
>=
0
;
i
--
{
siftDown
(
heap
,
i
,
k
)
}
// Process remaining elements - if larger than heap root, replace root
// Initialize min-heap with first k elements
h
:=
make
(
tokenHeap
,
k
)
copy
(
h
,
ts
[
:
k
])
heap
.
Init
(
&
h
)
// Process remaining elements
for
i
:=
k
;
i
<
len
(
ts
);
i
++
{
for
i
:=
k
;
i
<
len
(
ts
);
i
++
{
if
ts
[
i
]
.
value
>
h
eap
[
0
]
.
value
{
if
ts
[
i
]
.
value
>
h
[
0
]
.
value
{
heap
[
0
]
=
ts
[
i
]
heap
.
Pop
(
&
h
)
siftDown
(
heap
,
0
,
k
)
heap
.
Push
(
&
h
,
ts
[
i
]
)
}
}
}
}
slices
.
Reverse
(
heap
)
// Convert heap to sorted slice in descending order
result
:=
make
([]
token
,
k
)
for
i
:=
k
-
1
;
i
>=
0
;
i
--
{
result
[
i
]
=
heap
.
Pop
(
&
h
)
.
(
token
)
}
ts
=
heap
return
result
return
ts
}
}
// topP limits tokens to those with cumulative probability p
// topP limits tokens to those with cumulative probability p
...
@@ -135,61 +126,77 @@ func minP(ts []token, p float32) []token {
...
@@ -135,61 +126,77 @@ func minP(ts []token, p float32) []token {
return
ts
return
ts
}
}
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// partialSortLogits uses quickselect to efficiently find and sort the top n tokens
// sortLogits sorts tokens by their logit values using counting sort
func
partialSortLogits
(
ts
[]
token
,
n
int
)
[]
token
{
// counting sort is faster than built-in sort for this use case
if
n
>=
len
(
ts
)
{
func
sortLogits
(
tokens
[]
token
)
{
n
=
len
(
ts
)
if
len
(
tokens
)
<=
1
{
return
}
}
// Find max/min in a single pass
left
,
right
:=
0
,
len
(
ts
)
-
1
minLogit
,
maxLogit
:=
tokens
[
0
]
.
value
,
tokens
[
0
]
.
value
target
:=
n
-
1
for
_
,
t
:=
range
tokens
[
1
:
]
{
if
t
.
value
<
minLogit
{
minLogit
=
t
.
value
}
else
if
t
.
value
>
maxLogit
{
maxLogit
=
t
.
value
}
}
//
Calculate scaling to map to uint32 range
//
Quickselect algorithm to partition array around pivot
logitRange
:=
maxLogit
-
minLogit
for
left
<
right
{
if
logitRange
<
1e-6
{
// Choose middle element as pivot and move it to the end
return
// All values effectively equal
pivot
:=
left
+
(
right
-
left
)
/
2
}
ts
[
pivot
],
ts
[
right
]
=
ts
[
right
],
ts
[
pivot
]
//
Count frequencies directly from tokens
//
storeIndex tracks where to put next element greater than pivot
const
maxInt
=
(
1
<<
24
)
-
1
// Use 24 bits for good granularity
storeIndex
:=
left
var
counts
[
256
]
int
// For first byt
e
pivotValue
:=
ts
[
right
]
.
valu
e
// First pass: count frequencies
// Partition array into elements >= pivot and < pivot
for
_
,
t
:=
range
tokens
{
// Elements >= pivot go to the left side
// Map to [0, maxInt] range
for
i
:=
left
;
i
<
right
;
i
++
{
score
:=
min
(
uint32
((
t
.
value
-
minLogit
)
*
float32
(
maxInt
)
/
logitRange
),
maxInt
)
if
ts
[
i
]
.
value
>=
pivotValue
{
counts
[
score
>>
16
]
++
ts
[
storeIndex
],
ts
[
i
]
=
ts
[
i
],
ts
[
storeIndex
]
}
storeIndex
++
}
}
// Calculate offsets
// Move pivot to its final position
var
offset
int
ts
[
right
],
ts
[
storeIndex
]
=
ts
[
storeIndex
],
ts
[
right
]
for
i
:=
range
counts
{
count
:=
counts
[
i
]
// If pivot is at target position, we're done
counts
[
i
]
=
offset
// Otherwise recursively partition the half containing target
offset
+=
count
if
storeIndex
==
target
{
break
}
else
if
storeIndex
<
target
{
left
=
storeIndex
+
1
// Target is in right half
}
else
{
right
=
storeIndex
-
1
// Target is in left half
}
}
}
// Second pass: place elements in correct position
// Sort just the top n elements in descending order
output
:=
make
([]
token
,
len
(
tokens
))
slices
.
SortFunc
(
ts
[
:
n
],
func
(
a
,
b
token
)
int
{
// Track current positions
if
a
.
value
>
b
.
value
{
countsCopy
:=
counts
return
-
1
}
if
a
.
value
<
b
.
value
{
return
1
}
return
0
})
return
ts
[
:
n
]
}
for
i
,
t
:=
range
tokens
{
// sortLogits uses partialSortLogits to efficiently sort tokens
score
:=
min
(
uint32
((
t
.
value
-
minLogit
)
*
float32
(
maxInt
)
/
logitRange
),
maxInt
)
// It sorts approximately sqrt(len(tokens)) elements which balances
// between having enough tokens for sampling while avoiding full sort
func
sortLogits
(
ts
[]
token
)
{
// Use sqrt of token length as a heuristic for partial sort size
// This provides a good balance between performance and having enough tokens
n
:=
int
(
math
.
Sqrt
(
float64
(
len
(
ts
))))
+
1
pos
:=
countsCopy
[
score
>>
16
]
// Ensure we have at least 100 tokens and at most 1000
countsCopy
[
score
>>
16
]
++
switch
{
output
[
len
(
tokens
)
-
1
-
pos
]
=
tokens
[
i
]
case
n
<
100
:
n
=
100
case
n
>
1000
:
n
=
1000
}
}
copy
(
tokens
,
output
)
partialSortLogits
(
ts
,
n
)
}
}
sample/transforms_test.go
View file @
1b7433b7
package
sample
package
sample
import
(
import
(
"encoding/binary"
"errors"
"math"
"math"
"math/rand/v2"
"math/rand/v2"
"os"
"path/filepath"
"runtime"
"testing"
"testing"
)
)
// Helper to convert float
64
slice to logit slice
// Helper to convert float
32
slice to logit slice
func
toTokens
(
values
[]
float
64
)
[]
token
{
func
toTokens
(
values
[]
float
32
)
[]
token
{
tokens
:=
make
([]
token
,
len
(
values
))
tokens
:=
make
([]
token
,
len
(
values
))
for
i
,
v
:=
range
values
{
for
i
,
v
:=
range
values
{
tokens
[
i
]
=
token
{
tokens
[
i
]
=
token
{
id
:
int32
(
i
),
id
:
int32
(
i
),
value
:
float32
(
v
)
,
value
:
v
,
}
}
}
}
return
tokens
return
tokens
}
}
// Helper to compare logit slices
// Helper to compare logit slices
func
compareLogits
(
t
*
testing
.
T
,
name
string
,
want
[]
float
64
,
got
[]
token
)
{
func
compareLogits
(
t
*
testing
.
T
,
name
string
,
want
[]
float
32
,
got
[]
token
)
{
t
.
Helper
()
t
.
Helper
()
if
len
(
want
)
!=
len
(
got
)
{
if
len
(
want
)
!=
len
(
got
)
{
t
.
Errorf
(
"%s: length mismatch: want %d, got %d"
,
name
,
len
(
want
),
len
(
got
))
t
.
Errorf
(
"%s: length mismatch: want %d, got %d"
,
name
,
len
(
want
),
len
(
got
))
return
return
}
}
for
i
:=
range
want
{
for
i
:=
range
want
{
if
math
.
Abs
(
float64
(
got
[
i
]
.
value
)
-
want
[
i
])
>
1e-6
{
if
math
.
Abs
(
float64
(
got
[
i
]
.
value
-
want
[
i
])
)
>
1e-6
{
t
.
Errorf
(
"%s: index %d: want %f, got %f"
,
name
,
i
,
want
[
i
],
got
[
i
]
.
value
)
t
.
Errorf
(
"%s: index %d: want %f, got %f"
,
name
,
i
,
want
[
i
],
got
[
i
]
.
value
)
}
}
}
}
}
}
func
TestTemperatureAndSoftmax
(
t
*
testing
.
T
)
{
func
TestTemperatureAndSoftmax
(
t
*
testing
.
T
)
{
input
:=
[]
float
64
{
1
,
4
,
-
2
,
0
}
input
:=
[]
float
32
{
1
,
4
,
-
2
,
0
}
got
:=
temperature
(
toTokens
(
input
),
0.5
)
got
:=
temperature
(
toTokens
(
input
),
0.5
)
// Check probabilities sum to 1
// Check probabilities sum to 1
...
@@ -41,7 +46,7 @@ func TestTemperatureAndSoftmax(t *testing.T) {
...
@@ -41,7 +46,7 @@ func TestTemperatureAndSoftmax(t *testing.T) {
for
_
,
token
:=
range
got
{
for
_
,
token
:=
range
got
{
sum
+=
token
.
value
sum
+=
token
.
value
}
}
if
math
.
Abs
(
float64
(
sum
)
-
1.0
)
>
1e-6
{
if
math
.
Abs
(
float64
(
sum
-
1.0
)
)
>
1e-6
{
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
}
}
...
@@ -51,30 +56,31 @@ func TestTemperatureAndSoftmax(t *testing.T) {
...
@@ -51,30 +56,31 @@ func TestTemperatureAndSoftmax(t *testing.T) {
for
_
,
token
:=
range
got
{
for
_
,
token
:=
range
got
{
sum
+=
token
.
value
sum
+=
token
.
value
}
}
if
math
.
Abs
(
float64
(
sum
)
-
1.0
)
>
1e-6
{
if
math
.
Abs
(
float64
(
sum
-
1.0
)
)
>
1e-6
{
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
t
.
Errorf
(
"probabilities don't sum to 1: got %f"
,
sum
)
}
}
}
}
func
TestTopK
(
t
*
testing
.
T
)
{
func
TestTopK
(
t
*
testing
.
T
)
{
input
:=
[]
float
64
{
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
4
}
input
:=
[]
float
32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
// Test k=3
// Test k=3
got
:=
topK
(
toTokens
(
input
),
3
)
got
:=
topK
(
toTokens
(
input
),
5
)
if
len
(
got
)
!=
3
{
if
len
(
got
)
!=
5
{
t
.
Errorf
(
"topK(
3
): wrong length: want
3
, got %d"
,
len
(
got
))
t
.
Errorf
(
"topK(
5
): wrong length: want
5
, got %d"
,
len
(
got
))
}
}
// Should keep highest 3 values
: 4, 2, 1
// Should keep highest 5 values
in descending order
want
:=
[]
float
64
{
4
,
2
,
1
}
want
:=
[]
float
32
{
0.27755088
,
0.20409796
,
0.15720603
,
0.08582123
,
0.045046154
}
compareLogits
(
t
,
"topK(3)"
,
want
,
got
)
compareLogits
(
t
,
"topK(3)"
,
want
,
got
)
// Test k > len
got
=
topK
(
toTokens
(
input
),
20
)
got
=
topK
(
toTokens
(
input
),
10
)
if
len
(
got
)
!=
len
(
input
)
{
compareLogits
(
t
,
"topK(10)"
,
input
,
got
)
t
.
Errorf
(
"topK(20): wrong length: want %d, got %d"
,
len
(
input
),
len
(
got
))
}
}
}
func
TestTopP
(
t
*
testing
.
T
)
{
func
TestTopP
(
t
*
testing
.
T
)
{
input
:=
[]
float
64
{
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
4
}
input
:=
[]
float
32
{
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
4
}
tokens
:=
toTokens
(
input
)
tokens
:=
toTokens
(
input
)
// First apply temperature and softmax to get probabilities
// First apply temperature and softmax to get probabilities
...
@@ -92,7 +98,7 @@ func TestTopP(t *testing.T) {
...
@@ -92,7 +98,7 @@ func TestTopP(t *testing.T) {
}
}
func
TestMinP
(
t
*
testing
.
T
)
{
func
TestMinP
(
t
*
testing
.
T
)
{
input
:=
[]
float
64
{
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
4
,
3
}
input
:=
[]
float
32
{
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
4
,
3
}
tokens
:=
toTokens
(
input
)
tokens
:=
toTokens
(
input
)
// First apply temperature and softmax
// First apply temperature and softmax
...
@@ -108,7 +114,7 @@ func TestMinP(t *testing.T) {
...
@@ -108,7 +114,7 @@ func TestMinP(t *testing.T) {
}
}
func
TestSortLogits
(
t
*
testing
.
T
)
{
func
TestSortLogits
(
t
*
testing
.
T
)
{
input
:=
[]
float
64
{
3
,
1
,
4
,
2
,
-
1
,
0
,
-
2
}
input
:=
[]
float
32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
tokens
:=
toTokens
(
input
)
tokens
:=
toTokens
(
input
)
sortLogits
(
tokens
)
sortLogits
(
tokens
)
...
@@ -120,10 +126,102 @@ func TestSortLogits(t *testing.T) {
...
@@ -120,10 +126,102 @@ func TestSortLogits(t *testing.T) {
}
}
}
}
want
:=
[]
float
64
{
4
,
3
,
2
,
1
,
0
,
-
1
,
-
2
}
want
:=
[]
float
32
{
0.27755088
,
0.20409796
,
0.15720603
,
0.08582123
,
0.045046154
,
0.043722924
,
0.036774673
,
0.026986899
,
0.01681367
,
0.0046718004
,
0.00412893
,
0.0030491839
}
compareLogits
(
t
,
"sortLogits"
,
want
,
tokens
)
compareLogits
(
t
,
"sortLogits"
,
want
,
tokens
)
}
}
// TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
func
TestSortLogitsWithRealData
(
t
*
testing
.
T
)
{
// This will be populated from testdata/logits.bin
// Format: 32-bit float array in binary format
logits
,
err
:=
loadTestLogits
(
t
)
if
err
!=
nil
{
t
.
Skipf
(
"Skipping real logit test: %v"
,
err
)
return
}
tokens
:=
toTokens
(
logits
)
sortLogits
(
tokens
)
// Calculate n for verification
n
:=
int
(
math
.
Sqrt
(
float64
(
len
(
tokens
))))
+
1
if
n
>
1000
{
n
=
1000
}
else
if
n
<
100
{
n
=
100
}
t
.
Logf
(
"Testing with %d tokens, partial sorting top %d"
,
len
(
tokens
),
n
)
// Only verify the top n elements are sorted (which is what we guarantee)
// This is much faster than checking the entire array
topN
:=
tokens
[
:
n
]
for
i
:=
1
;
i
<
len
(
topN
);
i
++
{
if
topN
[
i
]
.
value
>
topN
[
i
-
1
]
.
value
{
t
.
Fatalf
(
"top %d tokens not properly sorted at index %d: %.15f > %.15f"
,
n
,
i
,
topN
[
i
]
.
value
,
topN
[
i
-
1
]
.
value
)
}
}
// Verify we didn't lose any high value tokens by checking that
// all tokens after position n are <= the nth token
// Do this in chunks to avoid timeouts on large arrays
nthValue
:=
tokens
[
n
-
1
]
.
value
const
chunkSize
=
1000
for
start
:=
n
;
start
<
len
(
tokens
);
start
+=
chunkSize
{
end
:=
min
(
start
+
chunkSize
,
len
(
tokens
))
for
i
:=
start
;
i
<
end
;
i
++
{
if
tokens
[
i
]
.
value
>
nthValue
{
t
.
Fatalf
(
"found higher value token after position %d: tokens[%d].value = %.15f > %.15f"
,
n
,
i
,
tokens
[
i
]
.
value
,
nthValue
)
}
}
}
}
// loadTestLogits reads testdata/logits.bin, located next to this source file,
// and decodes it as a sequence of little-endian 32-bit floats.
//
// It returns a nil slice and an error when the source file path cannot be
// determined, the file cannot be read, its size is not a multiple of 4 bytes,
// or it is empty.
func loadTestLogits(t *testing.T) ([]float32, error) {
	t.Helper()

	_, currFile, _, ok := runtime.Caller(0)
	if !ok {
		return nil, errors.New("could not determine test file path")
	}
	testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")

	// Read the whole file in one call; decoding from memory avoids issuing
	// one unbuffered 4-byte read per value.
	raw, err := os.ReadFile(testDataPath)
	if err != nil {
		return nil, err
	}

	const floatSize = 4 // each float32 is 4 bytes
	if len(raw)%floatSize != 0 {
		return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
	}
	if len(raw) == 0 {
		return nil, errors.New("logits.bin is empty")
	}

	logits := make([]float32, len(raw)/floatSize)
	for i := range logits {
		bits := binary.LittleEndian.Uint32(raw[i*floatSize:])
		logits[i] = math.Float32frombits(bits)
	}
	return logits, nil
}
func
BenchmarkTransforms
(
b
*
testing
.
B
)
{
func
BenchmarkTransforms
(
b
*
testing
.
B
)
{
// Generate random logits
// Generate random logits
tokens
:=
make
([]
token
,
1
<<
16
)
tokens
:=
make
([]
token
,
1
<<
16
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment