Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
6208d622
Unverified
Commit
6208d622
authored
May 12, 2023
by
Woosuk Kwon
Committed by
GitHub
May 12, 2023
Browse files
Minor code cleaning for SamplingParams (#99)
parent
42f1042e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
49 deletions
+50
-49
cacheflow/sampling_params.py
cacheflow/sampling_params.py
+50
-49
No files found.
cacheflow/sampling_params.py
View file @
6208d622
from
typing
import
Dict
,
Set
from
typing
import
Set
class
SamplingParams
:
class
SamplingParams
:
...
@@ -16,54 +16,6 @@ class SamplingParams:
...
@@ -16,54 +16,6 @@ class SamplingParams:
max_tokens
:
int
=
16
,
max_tokens
:
int
=
16
,
logprobs
:
int
=
0
,
logprobs
:
int
=
0
,
)
->
None
:
)
->
None
:
if
n
<
1
:
raise
ValueError
(
f
"n must be at least 1, got
{
n
}
."
)
if
not
-
2.0
<=
presence_penalty
<=
2.0
:
raise
ValueError
(
f
"presence_penalty must be in [-2, 2], got
{
presence_penalty
}
."
)
if
not
-
2.0
<=
frequency_penalty
<=
2.0
:
raise
ValueError
(
f
"frequency_penalty must be in [-2, 2], got
{
frequency_penalty
}
."
)
if
temperature
<
0.0
:
raise
ValueError
(
f
"temperature must be non-negative, got
{
temperature
}
."
)
if
not
0.0
<
top_p
<=
1.0
:
raise
ValueError
(
f
"top_p must be in (0, 1], got
{
top_p
}
."
)
if
top_k
<
-
1
or
top_k
==
0
:
raise
ValueError
(
f
"top_k must be -1 (disable), or at least 1, "
f
"got
{
top_k
}
."
)
if
max_tokens
<
1
:
raise
ValueError
(
f
"max_tokens must be at least 1, got
{
max_tokens
}
."
)
if
logprobs
<
0
:
raise
ValueError
(
f
"logprobs must be non-negative, got
{
logprobs
}
."
)
if
use_beam_search
:
if
n
==
1
:
raise
ValueError
(
"n must be greater than 1 when using beam search."
)
if
temperature
>
0.0
:
raise
ValueError
(
"temperature must be 0 when using beam search."
)
if
top_p
<
1.0
:
raise
ValueError
(
"top_p must be 1 when using beam search."
)
if
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using beam search."
)
elif
temperature
==
0.0
:
# Zero temperature means greedy sampling.
if
n
>
1
:
raise
ValueError
(
"n must be 1 when using greedy sampling."
)
if
top_p
<
1.0
:
raise
ValueError
(
"top_p must be 1 when using greedy sampling."
)
if
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using greedy sampling."
)
self
.
n
=
n
self
.
n
=
n
self
.
presence_penalty
=
presence_penalty
self
.
presence_penalty
=
presence_penalty
self
.
frequency_penalty
=
frequency_penalty
self
.
frequency_penalty
=
frequency_penalty
...
@@ -75,6 +27,55 @@ class SamplingParams:
...
@@ -75,6 +27,55 @@ class SamplingParams:
self
.
max_tokens
=
max_tokens
self
.
max_tokens
=
max_tokens
self
.
logprobs
=
logprobs
self
.
logprobs
=
logprobs
self
.
_verify_args
()
if
self
.
use_beam_search
:
self
.
_verity_beam_search
()
elif
self
.
temperature
==
0.0
:
# Zero temperature means greedy sampling.
self
.
_verify_greedy_sampling
()
def
_verify_args
(
self
)
->
None
:
if
self
.
n
<
1
:
raise
ValueError
(
f
"n must be at least 1, got
{
self
.
n
}
."
)
if
not
-
2.0
<=
self
.
presence_penalty
<=
2.0
:
raise
ValueError
(
"presence_penalty must be in [-2, 2], got "
f
"
{
self
.
presence_penalty
}
."
)
if
not
-
2.0
<=
self
.
frequency_penalty
<=
2.0
:
raise
ValueError
(
"frequency_penalty must be in [-2, 2], got "
f
"
{
self
.
frequency_penalty
}
."
)
if
self
.
temperature
<
0.0
:
raise
ValueError
(
f
"temperature must be non-negative, got
{
self
.
temperature
}
."
)
if
not
0.0
<
self
.
top_p
<=
1.0
:
raise
ValueError
(
f
"top_p must be in (0, 1], got
{
self
.
top_p
}
."
)
if
self
.
top_k
<
-
1
or
self
.
top_k
==
0
:
raise
ValueError
(
f
"top_k must be -1 (disable), or at least 1, "
f
"got
{
self
.
top_k
}
."
)
if
self
.
max_tokens
<
1
:
raise
ValueError
(
f
"max_tokens must be at least 1, got
{
self
.
max_tokens
}
."
)
if
self
.
logprobs
<
0
:
raise
ValueError
(
f
"logprobs must be non-negative, got
{
self
.
logprobs
}
."
)
def
_verity_beam_search
(
self
)
->
None
:
if
self
.
n
==
1
:
raise
ValueError
(
"n must be greater than 1 when using beam search."
)
if
self
.
temperature
>
0.0
:
raise
ValueError
(
"temperature must be 0 when using beam search."
)
if
self
.
top_p
<
1.0
:
raise
ValueError
(
"top_p must be 1 when using beam search."
)
if
self
.
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using beam search."
)
def
_verify_greedy_sampling
(
self
)
->
None
:
if
self
.
n
>
1
:
raise
ValueError
(
"n must be 1 when using greedy sampling."
)
if
self
.
top_p
<
1.0
:
raise
ValueError
(
"top_p must be 1 when using greedy sampling."
)
if
self
.
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using greedy sampling."
)
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SamplingParams(n=
{
self
.
n
}
, "
return
(
f
"SamplingParams(n=
{
self
.
n
}
, "
f
"presence_penalty=
{
self
.
presence_penalty
}
, "
f
"presence_penalty=
{
self
.
presence_penalty
}
, "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment