Unverified Commit 241ad7b3 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[ci] Fix sampler tests (#11922)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent d85c47d6
......@@ -214,6 +214,7 @@ steps:
- vllm/model_executor/layers
- vllm/sampling_metadata.py
- tests/samplers
- tests/conftest.py
commands:
- pytest -v -s samplers
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
......
......@@ -28,12 +28,13 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
TokensPrompt, to_enc_dec_tuple_list,
zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity)
identity, is_list_of)
logger = init_logger(__name__)
......@@ -886,6 +887,12 @@ class VllmRunner:
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
if is_list_of(prompts, str, check="all"):
prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
else:
prompts = [
TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
]
outputs = self.model.beam_search(
prompts,
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment