Unverified Commit 91528575 authored by nunjunj's avatar nunjunj Committed by GitHub
Browse files

[Frontend] multiple sampling params support (#3570)

parent a22cdea3
import pytest
from vllm import LLM, SamplingParams
def test_multiple_sampling_params():
llm = LLM(model="facebook/opt-125m",
max_num_batched_tokens=4096,
tensor_parallel_size=1)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = [
SamplingParams(temperature=0.01, top_p=0.95),
SamplingParams(temperature=0.3, top_p=0.95),
SamplingParams(temperature=0.7, top_p=0.95),
SamplingParams(temperature=0.99, top_p=0.95),
]
# Multiple SamplingParams should be matched with each prompt
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(prompts) == len(outputs)
# Exception raised, if the size of params does not match the size of prompts
with pytest.raises(ValueError):
outputs = llm.generate(prompts, sampling_params=sampling_params[:3])
# Single SamplingParams should be applied to every prompt
single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
outputs = llm.generate(prompts, sampling_params=single_sampling_params)
assert len(prompts) == len(outputs)
# sampling_params is None, default params should be applied
outputs = llm.generate(prompts, sampling_params=None)
assert len(prompts) == len(outputs)
\ No newline at end of file
...@@ -127,7 +127,8 @@ class LLM: ...@@ -127,7 +127,8 @@ class LLM:
def generate( def generate(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -143,6 +144,9 @@ class LLM: ...@@ -143,6 +144,9 @@ class LLM:
prompts: A list of prompts to generate completions for. prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
...@@ -163,27 +167,33 @@ class LLM: ...@@ -163,27 +167,33 @@ class LLM:
and len(prompts) != len(prompt_token_ids)): and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids " raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.") "must be the same.")
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
elif isinstance(sampling_params,
list) and len(sampling_params) != num_requests:
raise ValueError("The lengths of prompts and sampling_params "
"must be the same.")
if multi_modal_data: if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16) multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine. # Add requests to the engine.
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
for i in range(num_requests): for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[ token_ids = None if prompt_token_ids is None else prompt_token_ids[
i] i]
self._add_request( self._add_request(
prompt, prompt,
sampling_params, sampling_params[i]
if isinstance(sampling_params, list) else sampling_params,
token_ids, token_ids,
lora_request=lora_request, lora_request=lora_request,
# Get ith image while maintaining the batch dim. # Get ith image while maintaining the batch dim.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment