Unverified commit 6e865b6a, authored by Chukwuma Nwaugha and committed by GitHub
Browse files

Refactor example prompts fixture (#29854)

Signed-off-by: nwaughac@gmail.com
parent d698bb38
...@@ -27,7 +27,7 @@ import threading ...@@ -27,7 +27,7 @@ import threading
from collections.abc import Generator from collections.abc import Generator
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from typing import Any, Callable, TypedDict, TypeVar, cast from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING
import numpy as np import numpy as np
import pytest import pytest
...@@ -67,6 +67,11 @@ from vllm.transformers_utils.utils import maybe_model_redirect ...@@ -67,6 +67,11 @@ from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.generation.utils import GenerateOutput
logger = init_logger(__name__) logger = init_logger(__name__)
_TEST_DIR = os.path.dirname(__file__) _TEST_DIR = os.path.dirname(__file__)
...@@ -202,10 +207,7 @@ def dynamo_reset(): ...@@ -202,10 +207,7 @@ def dynamo_reset():
@pytest.fixture @pytest.fixture
def example_prompts() -> list[str]: def example_prompts() -> list[str]:
prompts = [] return [prompt for filename in _TEST_PROMPTS for prompt in _read_prompts(filename)]
for filename in _TEST_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture @pytest.fixture
...@@ -224,10 +226,7 @@ class DecoderPromptType(Enum): ...@@ -224,10 +226,7 @@ class DecoderPromptType(Enum):
@pytest.fixture @pytest.fixture
def example_long_prompts() -> list[str]: def example_long_prompts() -> list[str]:
prompts = [] return [prompt for filename in _LONG_PROMPTS for prompt in _read_prompts(filename)]
for filename in _LONG_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
...@@ -353,10 +352,13 @@ class HfRunner: ...@@ -353,10 +352,13 @@ class HfRunner:
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
model = auto_cls.from_pretrained( model = cast(
nn.Module,
auto_cls.from_pretrained(
model_name, model_name,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**model_kwargs, **model_kwargs,
),
) )
# in case some unquantized custom models are not in same dtype # in case some unquantized custom models are not in same dtype
...@@ -374,11 +376,13 @@ class HfRunner: ...@@ -374,11 +376,13 @@ class HfRunner:
self.model = model self.model = model
if not skip_tokenizer_init: if not skip_tokenizer_init:
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
AutoTokenizer.from_pretrained(
model_name, model_name,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
)
# don't put this import at the top level # don't put this import at the top level
# it will call torch.cuda.device_count() # it will call torch.cuda.device_count()
...@@ -495,7 +499,7 @@ class HfRunner: ...@@ -495,7 +499,7 @@ class HfRunner:
outputs: list[tuple[list[list[int]], list[str]]] = [] outputs: list[tuple[list[list[int]], list[str]]] = []
for inputs in all_inputs: for inputs in all_inputs:
output_ids = self.model.generate( output_ids: torch.Tensor = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
use_cache=True, use_cache=True,
**kwargs, **kwargs,
...@@ -505,8 +509,7 @@ class HfRunner: ...@@ -505,8 +509,7 @@ class HfRunner:
skip_special_tokens=True, skip_special_tokens=True,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
) )
output_ids = output_ids.cpu().tolist() outputs.append((output_ids.cpu().tolist(), output_str))
outputs.append((output_ids, output_str))
return outputs return outputs
def generate_greedy( def generate_greedy(
...@@ -574,7 +577,7 @@ class HfRunner: ...@@ -574,7 +577,7 @@ class HfRunner:
all_logprobs: list[list[torch.Tensor]] = [] all_logprobs: list[list[torch.Tensor]] = []
for inputs in all_inputs: for inputs in all_inputs:
output = self.model.generate( output: "GenerateOutput" = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
use_cache=True, use_cache=True,
do_sample=False, do_sample=False,
...@@ -656,7 +659,7 @@ class HfRunner: ...@@ -656,7 +659,7 @@ class HfRunner:
all_output_strs: list[str] = [] all_output_strs: list[str] = []
for inputs in all_inputs: for inputs in all_inputs:
output = self.model.generate( output: "GenerateOutput" = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
use_cache=True, use_cache=True,
do_sample=False, do_sample=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment