Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
4aeb67ef
Commit
4aeb67ef
authored
Mar 12, 2025
by
ParthSareen
Committed by
Parth Sareen
Mar 12, 2025
Browse files
sample: do all sorting in topK
parent
3ba91634
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
25 deletions
+35
-25
sample/samplers.go
sample/samplers.go
+2
-5
sample/transforms.go
sample/transforms.go
+11
-16
sample/transforms_test.go
sample/transforms_test.go
+22
-4
No files found.
sample/samplers.go
View file @
4aeb67ef
...
@@ -84,11 +84,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
...
@@ -84,11 +84,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return
greedy
(
tokens
),
nil
return
greedy
(
tokens
),
nil
}
}
if
s
.
topK
>
0
{
// topK also sorts the tokens in descending order of logits
tokens
=
topK
(
tokens
,
s
.
topK
)
tokens
=
topK
(
tokens
,
s
.
topK
)
}
else
{
sortLogits
(
tokens
)
}
// token logit values are updated to probabilities
// token logit values are updated to probabilities
tokens
=
temperature
(
tokens
,
s
.
temperature
)
tokens
=
temperature
(
tokens
,
s
.
temperature
)
...
...
sample/transforms.go
View file @
4aeb67ef
...
@@ -53,8 +53,17 @@ func temperature(ts []token, temp float32) []token {
...
@@ -53,8 +53,17 @@ func temperature(ts []token, temp float32) []token {
// topK limits the number of tokens considered to the k highest logits
// topK limits the number of tokens considered to the k highest logits
func
topK
(
ts
[]
token
,
k
int
)
[]
token
{
func
topK
(
ts
[]
token
,
k
int
)
[]
token
{
if
k
>=
len
(
ts
)
{
if
k
>=
len
(
ts
)
||
k
<=
0
{
sortLogits
(
ts
)
slices
.
SortFunc
(
ts
,
func
(
a
,
b
token
)
int
{
switch
{
case
a
.
value
<
b
.
value
:
return
1
case
a
.
value
>
b
.
value
:
return
-
1
default
:
return
0
}
})
return
ts
return
ts
}
}
...
@@ -125,17 +134,3 @@ func minP(ts []token, p float32) []token {
...
@@ -125,17 +134,3 @@ func minP(ts []token, p float32) []token {
ts
=
validTokens
ts
=
validTokens
return
ts
return
ts
}
}
// sortLogits sorts the tokens in descending order of logits
func
sortLogits
(
ts
[]
token
)
{
slices
.
SortFunc
(
ts
,
func
(
a
,
b
token
)
int
{
switch
{
case
a
.
value
<
b
.
value
:
return
1
case
a
.
value
>
b
.
value
:
return
-
1
default
:
return
0
}
})
}
sample/transforms_test.go
View file @
4aeb67ef
...
@@ -59,7 +59,7 @@ func TestTemperatureAndSoftmax(t *testing.T) {
...
@@ -59,7 +59,7 @@ func TestTemperatureAndSoftmax(t *testing.T) {
func
TestTopK
(
t
*
testing
.
T
)
{
func
TestTopK
(
t
*
testing
.
T
)
{
input
:=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
input
:=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
// Test k=
3
// Test k=
5
got
:=
topK
(
toTokens
(
input
),
5
)
got
:=
topK
(
toTokens
(
input
),
5
)
if
len
(
got
)
!=
5
{
if
len
(
got
)
!=
5
{
t
.
Errorf
(
"topK(5): wrong length: want 5, got %d"
,
len
(
got
))
t
.
Errorf
(
"topK(5): wrong length: want 5, got %d"
,
len
(
got
))
...
@@ -72,6 +72,24 @@ func TestTopK(t *testing.T) {
...
@@ -72,6 +72,24 @@ func TestTopK(t *testing.T) {
if
len
(
got
)
!=
len
(
input
)
{
if
len
(
got
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(20): wrong length: want %d, got %d"
,
len
(
input
),
len
(
got
))
t
.
Errorf
(
"topK(20): wrong length: want %d, got %d"
,
len
(
input
),
len
(
got
))
}
}
// Test k=-1
input
=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
want
=
[]
float32
{
0.27755088
,
0.20409796
,
0.15720603
,
0.08582123
,
0.045046154
,
0.043722924
,
0.036774673
,
0.026986899
,
0.01681367
,
0.0046718004
,
0.00412893
,
0.0030491839
}
got
=
topK
(
toTokens
(
input
),
-
1
)
if
len
(
got
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(-1): wrong length: want %d, got %d"
,
len
(
input
),
len
(
got
))
}
compareLogits
(
t
,
"topK(-1)"
,
want
,
got
)
// Test k=0
input
=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
want
=
[]
float32
{
0.27755088
,
0.20409796
,
0.15720603
,
0.08582123
,
0.045046154
,
0.043722924
,
0.036774673
,
0.026986899
,
0.01681367
,
0.0046718004
,
0.00412893
,
0.0030491839
}
got
=
topK
(
toTokens
(
input
),
0
)
if
len
(
got
)
!=
len
(
input
)
{
t
.
Errorf
(
"topK(-1): wrong length: want %d, got %d"
,
len
(
input
),
len
(
got
))
}
compareLogits
(
t
,
"topK(-1)"
,
want
,
got
)
}
}
func
TestTopP
(
t
*
testing
.
T
)
{
func
TestTopP
(
t
*
testing
.
T
)
{
...
@@ -80,7 +98,7 @@ func TestTopP(t *testing.T) {
...
@@ -80,7 +98,7 @@ func TestTopP(t *testing.T) {
// First apply temperature and softmax to get probabilities
// First apply temperature and softmax to get probabilities
tokens
=
temperature
(
tokens
,
1
)
tokens
=
temperature
(
tokens
,
1
)
sortLogits
(
tokens
)
tokens
=
topK
(
tokens
,
20
)
// Then apply topP
// Then apply topP
got
:=
topP
(
tokens
,
0.95
)
got
:=
topP
(
tokens
,
0.95
)
...
@@ -112,7 +130,7 @@ func TestSortLogits(t *testing.T) {
...
@@ -112,7 +130,7 @@ func TestSortLogits(t *testing.T) {
input
:=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
input
:=
[]
float32
{
0.026986899
,
0.043722924
,
0.036774673
,
0.27755088
,
0.0046718004
,
0.08582123
,
0.20409796
,
0.00412893
,
0.15720603
,
0.045046154
,
0.0030491839
,
0.01681367
}
tokens
:=
toTokens
(
input
)
tokens
:=
toTokens
(
input
)
sortLogits
(
tokens
)
tokens
=
topK
(
tokens
,
20
)
for
i
:=
1
;
i
<
len
(
tokens
);
i
++
{
for
i
:=
1
;
i
<
len
(
tokens
);
i
++
{
if
tokens
[
i
]
.
value
>
tokens
[
i
-
1
]
.
value
{
if
tokens
[
i
]
.
value
>
tokens
[
i
-
1
]
.
value
{
...
@@ -173,7 +191,7 @@ func BenchmarkTransforms(b *testing.B) {
...
@@ -173,7 +191,7 @@ func BenchmarkTransforms(b *testing.B) {
b
.
ResetTimer
()
b
.
ResetTimer
()
for
b
.
Loop
()
{
for
b
.
Loop
()
{
copy
(
tokensCopy
,
tokens
)
copy
(
tokensCopy
,
tokens
)
sortLogits
(
tokensCopy
)
topK
(
tokensCopy
,
200000
)
}
}
})
})
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment