Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
c245b040
Unverified
Commit
c245b040
authored
Feb 27, 2025
by
Parth Sareen
Committed by
GitHub
Feb 27, 2025
Browse files
sample: remove transforms from greedy sampling (#9377)
parent
8b194b75
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
86 deletions
+54
-86
sample/samplers.go
sample/samplers.go
+16
-35
sample/samplers_test.go
sample/samplers_test.go
+38
-51
No files found.
sample/samplers.go
View file @
c245b040
...
@@ -54,53 +54,42 @@ func (s weighted) Sample(logits []float32) (int32, error) {
...
@@ -54,53 +54,42 @@ func (s weighted) Sample(logits []float32) (int32, error) {
if
idx
,
ok
:=
w
.
Take
();
ok
{
if
idx
,
ok
:=
w
.
Take
();
ok
{
return
int32
(
indices
[
idx
]),
nil
return
int32
(
indices
[
idx
]),
nil
}
}
return
-
1
,
errors
.
New
(
"weighed sampler failed, no valid token found"
)
return
-
1
,
errors
.
New
(
"weigh
t
ed sampler failed, no valid token found"
)
}
}
type
greedy
struct
{
type
greedy
struct
{}
transforms
[]
Transform
}
func
Greedy
(
transforms
...
Transform
)
Sampler
{
func
Greedy
()
Sampler
{
return
greedy
{
transforms
:
transforms
}
return
greedy
{}
}
}
// Sample returns the index of the maximum value in logits.
func
(
s
greedy
)
Sample
(
logits
[]
float32
)
(
int32
,
error
)
{
func
(
s
greedy
)
Sample
(
logits
[]
float32
)
(
int32
,
error
)
{
logits64
:=
make
([]
float64
,
len
(
logits
))
if
len
(
logits
)
==
0
{
for
i
,
v
:=
range
logits
{
return
-
1
,
errors
.
New
(
"no logits provided for greedy sampling"
)
logits64
[
i
]
=
float64
(
v
)
}
}
for
_
,
t
:=
range
s
.
transforms
{
maxIdx
:=
0
logits64
=
t
.
Apply
(
logits64
)
for
i
:=
range
logits
{
}
if
logits
[
i
]
>
logits
[
maxIdx
]
{
var
maxIdx
int
var
maxLogit
float64
for
i
,
logit
:=
range
logits64
{
if
logit
>
maxLogit
{
maxLogit
=
logit
maxIdx
=
i
maxIdx
=
i
}
}
}
}
if
maxLogit
==
math
.
Inf
(
-
1
)
{
return
-
1
,
errors
.
New
(
"no valid logits found for greedy sampling"
)
}
return
int32
(
maxIdx
),
nil
return
int32
(
maxIdx
),
nil
}
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func
NewSampler
(
temperature
float32
,
topK
int
,
topP
float32
,
minP
float32
,
seed
int
)
(
Sampler
,
error
)
{
func
NewSampler
(
temperature
float32
,
topK
int
,
topP
float32
,
minP
float32
,
seed
int
)
(
Sampler
,
error
)
{
transforms
:=
[]
Transform
{}
if
temperature
==
0
{
return
Greedy
(),
nil
}
if
temperature
<
0
||
temperature
>
2
{
if
temperature
<
0
||
temperature
>
2
{
return
nil
,
errors
.
New
(
"temperature must be between 0 and 2"
)
return
nil
,
errors
.
New
(
"temperature must be between 0 and 2"
)
}
}
if
temperature
!=
0
{
transforms
:=
[]
Transform
{
Temperature
(
temperature
)}
transforms
=
append
(
transforms
,
Temperature
(
temperature
))
}
if
topK
!=
0
{
if
topK
!=
0
{
if
topK
<=
0
{
if
topK
<=
0
{
...
@@ -123,15 +112,7 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
...
@@ -123,15 +112,7 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
transforms
=
append
(
transforms
,
MinP
(
minP
))
transforms
=
append
(
transforms
,
MinP
(
minP
))
}
}
if
len
(
transforms
)
==
0
{
if
seed
>=
0
{
return
nil
,
errors
.
New
(
"at least one transform is required"
)
}
if
temperature
==
0
{
return
Greedy
(
transforms
...
),
nil
}
if
seed
!=
0
{
seed64
:=
uint64
(
seed
)
seed64
:=
uint64
(
seed
)
return
Weighted
(
&
seed64
,
transforms
...
),
nil
return
Weighted
(
&
seed64
,
transforms
...
),
nil
}
}
...
...
sample/samplers_test.go
View file @
c245b040
...
@@ -66,32 +66,15 @@ func TestSample(t *testing.T) {
...
@@ -66,32 +66,15 @@ func TestSample(t *testing.T) {
callOrder
:
&
callOrder
,
callOrder
:
&
callOrder
,
}
}
got
,
err
:=
Greedy
(
mock1
,
mock2
,
mock3
)
.
Sample
(
input
)
_
,
err
:=
Weighted
(
nil
,
mock1
,
mock2
,
mock3
)
.
Sample
(
input
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
return
return
}
}
want
:=
int32
(
3
)
// Greedy sampler should pick highest logit
if
want
!=
got
{
t
.
Errorf
(
"index mismatch: want %d, got %d"
,
want
,
got
)
}
wantOrder
:=
[]
int
{
1
,
2
,
3
}
wantOrder
:=
[]
int
{
1
,
2
,
3
}
if
diff
:=
cmp
.
Diff
(
wantOrder
,
callOrder
);
diff
!=
""
{
if
diff
:=
cmp
.
Diff
(
wantOrder
,
callOrder
);
diff
!=
""
{
t
.
Errorf
(
"call order mismatch (-want +got):
\n
%s"
,
diff
)
t
.
Errorf
(
"call order mismatch (-want +got):
\n
%s"
,
diff
)
}
}
callOrder
=
nil
_
,
err
=
Weighted
(
nil
,
mock1
,
mock2
,
mock3
)
.
Sample
(
input
)
if
err
!=
nil
{
t
.
Error
(
err
)
return
}
wantOrder
=
[]
int
{
1
,
2
,
3
}
if
diff
:=
cmp
.
Diff
(
wantOrder
,
callOrder
);
diff
!=
""
{
t
.
Errorf
(
"call order mismatch (-want +got):
\n
%s"
,
diff
)
}
}
}
func
TestNewSampler
(
t
*
testing
.
T
)
{
func
TestNewSampler
(
t
*
testing
.
T
)
{
...
@@ -106,7 +89,8 @@ func TestNewSampler(t *testing.T) {
...
@@ -106,7 +89,8 @@ func TestNewSampler(t *testing.T) {
}{
}{
{
{
name
:
"no transforms"
,
name
:
"no transforms"
,
wantErr
:
true
,
// temperature is 0, so greedy should be used
wantErr
:
false
,
},
},
{
{
name
:
"temperature"
,
name
:
"temperature"
,
...
@@ -126,48 +110,51 @@ func TestNewSampler(t *testing.T) {
...
@@ -126,48 +110,51 @@ func TestNewSampler(t *testing.T) {
{
{
name
:
"top k"
,
name
:
"top k"
,
topK
:
10
,
topK
:
10
,
temperature
:
0.8
,
wantErr
:
false
,
wantErr
:
false
,
},
},
{
{
name
:
"invalid top k negative"
,
name
:
"invalid top k negative"
,
topK
:
-
1
,
topK
:
-
1
,
temperature
:
0.8
,
wantErr
:
true
,
wantErr
:
true
,
},
},
{
{
name
:
"top p"
,
name
:
"top p"
,
topP
:
0.9
,
topP
:
0.9
,
temperature
:
0.8
,
wantErr
:
false
,
wantErr
:
false
,
},
},
{
{
name
:
"invalid top p negative"
,
name
:
"invalid top p negative"
,
topP
:
-
0.1
,
topP
:
-
0.1
,
temperature
:
0.8
,
wantErr
:
true
,
wantErr
:
true
,
},
},
{
{
name
:
"invalid top p one"
,
name
:
"invalid top p one"
,
topP
:
1.0
,
topP
:
1.0
,
temperature
:
0.8
,
wantErr
:
true
,
wantErr
:
true
,
},
},
{
{
name
:
"min p"
,
name
:
"min p"
,
minP
:
0.2
,
minP
:
0.2
,
temperature
:
0.8
,
wantErr
:
false
,
wantErr
:
false
,
},
},
{
{
name
:
"invalid min p negative"
,
name
:
"invalid min p negative"
,
minP
:
-
0.1
,
minP
:
-
0.1
,
temperature
:
0.8
,
wantErr
:
true
,
wantErr
:
true
,
},
},
{
{
name
:
"invalid min p one"
,
name
:
"invalid min p one"
,
minP
:
1.0
,
minP
:
1.0
,
temperature
:
0.8
,
wantErr
:
true
,
wantErr
:
true
,
},
},
{
name
:
"seed"
,
seed
:
42
,
wantErr
:
true
,
// seed alone is not valid without other transforms
},
{
{
name
:
"default values"
,
name
:
"default values"
,
temperature
:
0.8
,
temperature
:
0.8
,
...
@@ -184,7 +171,7 @@ func TestNewSampler(t *testing.T) {
...
@@ -184,7 +171,7 @@ func TestNewSampler(t *testing.T) {
topP
:
0.0
,
topP
:
0.0
,
minP
:
0.0
,
minP
:
0.0
,
seed
:
0
,
seed
:
0
,
wantErr
:
tru
e
,
// all zeroes means no transforms
wantErr
:
fals
e
,
// all zeroes means no transforms
},
},
{
{
name
:
"all transforms"
,
name
:
"all transforms"
,
...
@@ -216,7 +203,7 @@ func BenchmarkSample(b *testing.B) {
...
@@ -216,7 +203,7 @@ func BenchmarkSample(b *testing.B) {
}
}
samplers
:=
map
[
string
]
Sampler
{
samplers
:=
map
[
string
]
Sampler
{
"Greedy"
:
Greedy
(
transforms
...
),
"Greedy"
:
Greedy
(),
"Weighted"
:
Weighted
(
nil
,
transforms
...
),
"Weighted"
:
Weighted
(
nil
,
transforms
...
),
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment