Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xdb4_94051
vllm
Commits
6208d622
Unverified
Commit
6208d622
authored
May 12, 2023
by
Woosuk Kwon
Committed by
GitHub
May 12, 2023
Browse files
Minor code cleaning for SamplingParams (#99)
parent
42f1042e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
49 deletions
+50
-49
cacheflow/sampling_params.py
cacheflow/sampling_params.py
+50
-49
No files found.
cacheflow/sampling_params.py
View file @
6208d622
from
typing
import
Dict
,
Set
from
typing
import
Set
class
SamplingParams
:
...
...
@@ -16,54 +16,6 @@ class SamplingParams:
max_tokens
:
int
=
16
,
logprobs
:
int
=
0
,
)
->
None
:
if
n
<
1
:
raise
ValueError
(
f
"n must be at least 1, got
{
n
}
."
)
if
not
-
2.0
<=
presence_penalty
<=
2.0
:
raise
ValueError
(
f
"presence_penalty must be in [-2, 2], got
{
presence_penalty
}
."
)
if
not
-
2.0
<=
frequency_penalty
<=
2.0
:
raise
ValueError
(
f
"frequency_penalty must be in [-2, 2], got
{
frequency_penalty
}
."
)
if
temperature
<
0.0
:
raise
ValueError
(
f
"temperature must be non-negative, got
{
temperature
}
."
)
if
not
0.0
<
top_p
<=
1.0
:
raise
ValueError
(
f
"top_p must be in (0, 1], got
{
top_p
}
."
)
if
top_k
<
-
1
or
top_k
==
0
:
raise
ValueError
(
f
"top_k must be -1 (disable), or at least 1, "
f
"got
{
top_k
}
."
)
if
max_tokens
<
1
:
raise
ValueError
(
f
"max_tokens must be at least 1, got
{
max_tokens
}
."
)
if
logprobs
<
0
:
raise
ValueError
(
f
"logprobs must be non-negative, got
{
logprobs
}
."
)
if
use_beam_search
:
if
n
==
1
:
raise
ValueError
(
"n must be greater than 1 when using beam search."
)
if
temperature
>
0.0
:
raise
ValueError
(
"temperature must be 0 when using beam search."
)
if
top_p
<
1.0
:
raise
ValueError
(
"top_p must be 1 when using beam search."
)
if
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using beam search."
)
elif
temperature
==
0.0
:
# Zero temperature means greedy sampling.
if
n
>
1
:
raise
ValueError
(
"n must be 1 when using greedy sampling."
)
if
top_p
<
1.0
:
raise
ValueError
(
"top_p must be 1 when using greedy sampling."
)
if
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using greedy sampling."
)
self
.
n
=
n
self
.
presence_penalty
=
presence_penalty
self
.
frequency_penalty
=
frequency_penalty
...
...
@@ -75,6 +27,55 @@ class SamplingParams:
self
.
max_tokens
=
max_tokens
self
.
logprobs
=
logprobs
self
.
_verify_args
()
if
self
.
use_beam_search
:
self
.
_verity_beam_search
()
elif
self
.
temperature
==
0.0
:
# Zero temperature means greedy sampling.
self
.
_verify_greedy_sampling
()
def
_verify_args
(
self
)
->
None
:
if
self
.
n
<
1
:
raise
ValueError
(
f
"n must be at least 1, got
{
self
.
n
}
."
)
if
not
-
2.0
<=
self
.
presence_penalty
<=
2.0
:
raise
ValueError
(
"presence_penalty must be in [-2, 2], got "
f
"
{
self
.
presence_penalty
}
."
)
if
not
-
2.0
<=
self
.
frequency_penalty
<=
2.0
:
raise
ValueError
(
"frequency_penalty must be in [-2, 2], got "
f
"
{
self
.
frequency_penalty
}
."
)
if
self
.
temperature
<
0.0
:
raise
ValueError
(
f
"temperature must be non-negative, got
{
self
.
temperature
}
."
)
if
not
0.0
<
self
.
top_p
<=
1.0
:
raise
ValueError
(
f
"top_p must be in (0, 1], got
{
self
.
top_p
}
."
)
if
self
.
top_k
<
-
1
or
self
.
top_k
==
0
:
raise
ValueError
(
f
"top_k must be -1 (disable), or at least 1, "
f
"got
{
self
.
top_k
}
."
)
if
self
.
max_tokens
<
1
:
raise
ValueError
(
f
"max_tokens must be at least 1, got
{
self
.
max_tokens
}
."
)
if
self
.
logprobs
<
0
:
raise
ValueError
(
f
"logprobs must be non-negative, got
{
self
.
logprobs
}
."
)
def
_verity_beam_search
(
self
)
->
None
:
if
self
.
n
==
1
:
raise
ValueError
(
"n must be greater than 1 when using beam search."
)
if
self
.
temperature
>
0.0
:
raise
ValueError
(
"temperature must be 0 when using beam search."
)
if
self
.
top_p
<
1.0
:
raise
ValueError
(
"top_p must be 1 when using beam search."
)
if
self
.
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using beam search."
)
def
_verify_greedy_sampling
(
self
)
->
None
:
if
self
.
n
>
1
:
raise
ValueError
(
"n must be 1 when using greedy sampling."
)
if
self
.
top_p
<
1.0
:
raise
ValueError
(
"top_p must be 1 when using greedy sampling."
)
if
self
.
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using greedy sampling."
)
def
__repr__
(
self
)
->
str
:
return
(
f
"SamplingParams(n=
{
self
.
n
}
, "
f
"presence_penalty=
{
self
.
presence_penalty
}
, "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment