Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
91528575
Unverified
Commit
91528575
authored
Apr 20, 2024
by
nunjunj
Committed by
GitHub
Apr 20, 2024
Browse files
[Frontend] multiple sampling params support (#3570)
parent
a22cdea3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
61 additions
and
10 deletions
+61
-10
tests/entrypoints/test_llm_generate.py
tests/entrypoints/test_llm_generate.py
+41
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+20
-10
No files found.
tests/entrypoints/test_llm_generate.py
0 → 100644
View file @
91528575
import
pytest
from
vllm
import
LLM
,
SamplingParams
def test_multiple_sampling_params():
    """Exercise the ways ``LLM.generate`` accepts ``sampling_params``.

    Four call shapes are covered:

    * a list with one ``SamplingParams`` per prompt,
    * a list whose length differs from the number of prompts
      (expected to raise ``ValueError``),
    * a single ``SamplingParams`` instance,
    * ``None``.

    For every call that succeeds, the test only checks that the number of
    outputs equals the number of prompts.
    """
    engine = LLM(
        model="facebook/opt-125m",
        max_num_batched_tokens=4096,
        tensor_parallel_size=1,
    )

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # One SamplingParams per prompt; only the temperature varies.
    temperatures = (0.01, 0.3, 0.7, 0.99)
    per_prompt_params = [
        SamplingParams(temperature=temperature, top_p=0.95)
        for temperature in temperatures
    ]

    # A list of SamplingParams is matched with the prompts one-to-one.
    results = engine.generate(prompts, sampling_params=per_prompt_params)
    assert len(prompts) == len(results)

    # A list whose length differs from the number of prompts is rejected.
    with pytest.raises(ValueError):
        engine.generate(prompts, sampling_params=per_prompt_params[:3])

    # A single SamplingParams instance is accepted for the whole batch.
    shared_params = SamplingParams(temperature=0.3, top_p=0.95)
    results = engine.generate(prompts, sampling_params=shared_params)
    assert len(prompts) == len(results)

    # Passing None is accepted as well (default parameters are used).
    results = engine.generate(prompts, sampling_params=None)
    assert len(prompts) == len(results)
\ No newline at end of file
vllm/entrypoints/llm.py
View file @
91528575
...
...
@@ -127,7 +127,8 @@ class LLM:
def
generate
(
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -143,6 +144,9 @@ class LLM:
prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, it must have the same length as the
prompts, and each element is paired with the prompt at the same index.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
...
...
@@ -163,27 +167,33 @@ class LLM:
and
len
(
prompts
)
!=
len
(
prompt_token_ids
)):
raise
ValueError
(
"The lengths of prompts and prompt_token_ids "
"must be the same."
)
if
prompts
is
not
None
:
num_requests
=
len
(
prompts
)
else
:
assert
prompt_token_ids
is
not
None
num_requests
=
len
(
prompt_token_ids
)
if
sampling_params
is
None
:
# Use default sampling params.
sampling_params
=
SamplingParams
()
elif
isinstance
(
sampling_params
,
list
)
and
len
(
sampling_params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and sampling_params "
"must be the same."
)
if
multi_modal_data
:
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
# Add requests to the engine.
if
prompts
is
not
None
:
num_requests
=
len
(
prompts
)
else
:
assert
prompt_token_ids
is
not
None
num_requests
=
len
(
prompt_token_ids
)
for
i
in
range
(
num_requests
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
i
]
self
.
_add_request
(
prompt
,
sampling_params
,
sampling_params
[
i
]
if
isinstance
(
sampling_params
,
list
)
else
sampling_params
,
token_ids
,
lora_request
=
lora_request
,
# Get ith image while maintaining the batch dim.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment