Unverified Commit 241ad7b3 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[ci] Fix sampler tests (#11922)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent d85c47d6
...@@ -214,6 +214,7 @@ steps: ...@@ -214,6 +214,7 @@ steps:
- vllm/model_executor/layers - vllm/model_executor/layers
- vllm/sampling_metadata.py - vllm/sampling_metadata.py
- tests/samplers - tests/samplers
- tests/conftest.py
commands: commands:
- pytest -v -s samplers - pytest -v -s samplers
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
......
...@@ -28,12 +28,13 @@ from vllm.distributed import (cleanup_dist_env_and_memory, ...@@ -28,12 +28,13 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment, init_distributed_environment,
initialize_model_parallel) initialize_model_parallel)
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts) TokensPrompt, to_enc_dec_tuple_list,
zip_enc_dec_prompts)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity) identity, is_list_of)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -886,6 +887,12 @@ class VllmRunner: ...@@ -886,6 +887,12 @@ class VllmRunner:
beam_width: int, beam_width: int,
max_tokens: int, max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]: ) -> List[Tuple[List[List[int]], List[str]]]:
if is_list_of(prompts, str, check="all"):
prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
else:
prompts = [
TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
]
outputs = self.model.beam_search( outputs = self.model.beam_search(
prompts, prompts,
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment