Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
91528575
Unverified
Commit
91528575
authored
Apr 20, 2024
by
nunjunj
Committed by
GitHub
Apr 20, 2024
Browse files
[Frontend] multiple sampling params support (#3570)
parent
a22cdea3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
61 additions
and
10 deletions
+61
-10
tests/entrypoints/test_llm_generate.py
tests/entrypoints/test_llm_generate.py
+41
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+20
-10
No files found.
tests/entrypoints/test_llm_generate.py
0 → 100644
View file @
91528575
import
pytest
from
vllm
import
LLM
,
SamplingParams
def
test_multiple_sampling_params
():
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
max_num_batched_tokens
=
4096
,
tensor_parallel_size
=
1
)
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
[
SamplingParams
(
temperature
=
0.01
,
top_p
=
0.95
),
SamplingParams
(
temperature
=
0.3
,
top_p
=
0.95
),
SamplingParams
(
temperature
=
0.7
,
top_p
=
0.95
),
SamplingParams
(
temperature
=
0.99
,
top_p
=
0.95
),
]
# Multiple SamplingParams should be matched with each prompt
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
assert
len
(
prompts
)
==
len
(
outputs
)
# Exception raised, if the size of params does not match the size of prompts
with
pytest
.
raises
(
ValueError
):
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
[:
3
])
# Single SamplingParams should be applied to every prompt
single_sampling_params
=
SamplingParams
(
temperature
=
0.3
,
top_p
=
0.95
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
single_sampling_params
)
assert
len
(
prompts
)
==
len
(
outputs
)
# sampling_params is None, default params should be applied
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
None
)
assert
len
(
prompts
)
==
len
(
outputs
)
\ No newline at end of file
vllm/entrypoints/llm.py
View file @
91528575
...
@@ -127,7 +127,8 @@ class LLM:
...
@@ -127,7 +127,8 @@ class LLM:
def
generate
(
def
generate
(
self
,
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -142,7 +143,10 @@ class LLM:
...
@@ -142,7 +143,10 @@ class LLM:
Args:
Args:
prompts: A list of prompts to generate completions for.
prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: Whether to use tqdm to display the progress bar.
...
@@ -163,27 +167,33 @@ class LLM:
...
@@ -163,27 +167,33 @@ class LLM:
and
len
(
prompts
)
!=
len
(
prompt_token_ids
)):
and
len
(
prompts
)
!=
len
(
prompt_token_ids
)):
raise
ValueError
(
"The lengths of prompts and prompt_token_ids "
raise
ValueError
(
"The lengths of prompts and prompt_token_ids "
"must be the same."
)
"must be the same."
)
if
prompts
is
not
None
:
num_requests
=
len
(
prompts
)
else
:
assert
prompt_token_ids
is
not
None
num_requests
=
len
(
prompt_token_ids
)
if
sampling_params
is
None
:
if
sampling_params
is
None
:
# Use default sampling params.
# Use default sampling params.
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
()
elif
isinstance
(
sampling_params
,
list
)
and
len
(
sampling_params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and sampling_params "
"must be the same."
)
if
multi_modal_data
:
if
multi_modal_data
:
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
# Add requests to the engine.
# Add requests to the engine.
if
prompts
is
not
None
:
num_requests
=
len
(
prompts
)
else
:
assert
prompt_token_ids
is
not
None
num_requests
=
len
(
prompt_token_ids
)
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
i
]
i
]
self
.
_add_request
(
self
.
_add_request
(
prompt
,
prompt
,
sampling_params
,
sampling_params
[
i
]
if
isinstance
(
sampling_params
,
list
)
else
sampling_params
,
token_ids
,
token_ids
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
# Get ith image while maintaining the batch dim.
# Get ith image while maintaining the batch dim.
...
@@ -232,4 +242,4 @@ class LLM:
...
@@ -232,4 +242,4 @@ class LLM:
# This is necessary because some requests may be finished earlier than
# This is necessary because some requests may be finished earlier than
# its previous requests.
# its previous requests.
outputs
=
sorted
(
outputs
,
key
=
lambda
x
:
int
(
x
.
request_id
))
outputs
=
sorted
(
outputs
,
key
=
lambda
x
:
int
(
x
.
request_id
))
return
outputs
return
outputs
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment