Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
3209b490
Unverified
Commit
3209b490
authored
Jan 23, 2024
by
Nikola Borisov
Committed by
GitHub
Jan 23, 2024
Browse files
[Bugfix] fix crash if max_tokens=None (#2570)
parent
1e4277d2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
2 deletions
+28
-2
tests/test_regression.py
tests/test_regression.py
+13
-0
tests/test_sampling_params.py
tests/test_sampling_params.py
+13
-0
vllm/sampling_params.py
vllm/sampling_params.py
+2
-2
No files found.
tests/test_regression.py
View file @
3209b490
...
@@ -22,6 +22,19 @@ def test_duplicated_ignored_sequence_group():
...
@@ -22,6 +22,19 @@ def test_duplicated_ignored_sequence_group():
assert
len
(
prompts
)
==
len
(
outputs
)
assert
len
(
prompts
)
==
len
(
outputs
)
def
test_max_tokens_none
():
sampling_params
=
SamplingParams
(
temperature
=
0.01
,
top_p
=
0.1
,
max_tokens
=
None
)
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
max_num_batched_tokens
=
4096
,
tensor_parallel_size
=
1
)
prompts
=
[
"Just say hello!"
]
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
assert
len
(
prompts
)
==
len
(
outputs
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
import
pytest
import
pytest
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
tests/test_sampling_params.py
0 → 100644
View file @
3209b490
"""Tests for the SamplingParams class.
"""
from
vllm
import
SamplingParams
def
test_max_tokens_none
():
"""max_tokens=None should be allowed"""
SamplingParams
(
temperature
=
0.01
,
top_p
=
0.1
,
max_tokens
=
None
)
if
__name__
==
"__main__"
:
import
pytest
pytest
.
main
([
__file__
])
vllm/sampling_params.py
View file @
3209b490
...
@@ -108,7 +108,7 @@ class SamplingParams:
...
@@ -108,7 +108,7 @@ class SamplingParams:
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
include_stop_str_in_output
:
bool
=
False
,
include_stop_str_in_output
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
max_tokens
:
int
=
16
,
max_tokens
:
Optional
[
int
]
=
16
,
logprobs
:
Optional
[
int
]
=
None
,
logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
...
@@ -183,7 +183,7 @@ class SamplingParams:
...
@@ -183,7 +183,7 @@ class SamplingParams:
if
not
0.0
<=
self
.
min_p
<=
1.0
:
if
not
0.0
<=
self
.
min_p
<=
1.0
:
raise
ValueError
(
"min_p must be in [0, 1], got "
raise
ValueError
(
"min_p must be in [0, 1], got "
f
"
{
self
.
min_p
}
."
)
f
"
{
self
.
min_p
}
."
)
if
self
.
max_tokens
<
1
:
if
self
.
max_tokens
is
not
None
and
self
.
max_tokens
<
1
:
raise
ValueError
(
raise
ValueError
(
f
"max_tokens must be at least 1, got
{
self
.
max_tokens
}
."
)
f
"max_tokens must be at least 1, got
{
self
.
max_tokens
}
."
)
if
self
.
logprobs
is
not
None
and
self
.
logprobs
<
0
:
if
self
.
logprobs
is
not
None
and
self
.
logprobs
<
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment