Commit ef31eabc (unverified), authored by zhou fan, committed by GitHub
Browse files

[Model]: add some tests for aria model (#10770)


Signed-off-by: xffxff <1247714429@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent 995a1485
...@@ -656,6 +656,7 @@ class VllmRunner: ...@@ -656,6 +656,7 @@ class VllmRunner:
model_name: str, model_name: str,
task: TaskOption = "auto", task: TaskOption = "auto",
tokenizer_name: Optional[str] = None, tokenizer_name: Optional[str] = None,
tokenizer_mode: str = "auto",
# Use smaller max model length, otherwise bigger model cannot run due # Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit. # to kv cache size limit.
max_model_len: int = 1024, max_model_len: int = 1024,
...@@ -672,6 +673,7 @@ class VllmRunner: ...@@ -672,6 +673,7 @@ class VllmRunner:
model=model_name, model=model_name,
task=task, task=task,
tokenizer=tokenizer_name, tokenizer=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=True, trust_remote_code=True,
dtype=dtype, dtype=dtype,
swap_space=swap_space, swap_space=swap_space,
...@@ -842,6 +844,7 @@ class VllmRunner: ...@@ -842,6 +844,7 @@ class VllmRunner:
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
) -> Union[List[TokensTextLogprobs], ) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]: List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams( greedy_logprobs_params = SamplingParams(
...@@ -849,7 +852,8 @@ class VllmRunner: ...@@ -849,7 +852,8 @@ class VllmRunner:
max_tokens=max_tokens, max_tokens=max_tokens,
logprobs=num_logprobs, logprobs=num_logprobs,
prompt_logprobs=num_prompt_logprobs, prompt_logprobs=num_prompt_logprobs,
stop_token_ids=stop_token_ids) stop_token_ids=stop_token_ids,
stop=stop)
return self.generate_w_logprobs(prompts, return self.generate_w_logprobs(prompts,
greedy_logprobs_params, greedy_logprobs_params,
......
...@@ -8,6 +8,7 @@ from typing import Type ...@@ -8,6 +8,7 @@ from typing import Type
import pytest import pytest
import transformers import transformers
from transformers import AutoModelForVision2Seq from transformers import AutoModelForVision2Seq
from transformers.utils import is_flash_attn_2_available
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, identity from vllm.utils import cuda_device_count_stateless, identity
...@@ -134,6 +135,35 @@ VLM_TEST_SETTINGS = { ...@@ -134,6 +135,35 @@ VLM_TEST_SETTINGS = {
marks=[pytest.mark.core_model, pytest.mark.cpu_model], marks=[pytest.mark.core_model, pytest.mark.cpu_model],
), ),
#### Extended model tests #### Extended model tests
# Test settings for the Aria model ("rhymes-ai/Aria"); exercised with both
# single-image and multi-image inputs.
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
# Overrides the default tokenizer mode ("auto") with the slow tokenizer.
tokenizer_mode="slow",
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
),
dtype="bfloat16",
# Wraps the image prompt in a user turn and opens the assistant turn.
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
# The same placeholder text is returned for every image; idx is unused.
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
max_model_len=4096,
max_num_seqs=2,
# NOTE(review): "<vlm_image>" is presumably replaced by the output of
# img_idx_to_prompt by the test harness - confirm against the builders.
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<vlm_image>Please describe the image shortly.",
"cherry_blossom": "<vlm_image>Please infer the season with reason.",
}),
multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"),
# Generation is stopped on this literal string (passed as a stop string
# rather than through get_stop_token_ids).
stop_str=["<|im_end|>"],
image_size_factors=[(0.10, 0.15)],
max_tokens=64,
marks=[
# Skip the test when flash-attn 2 is not available.
pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="Model needs flash-attn for numeric convergence.",
),
# NOTE(review): presumably limits the test to GPUs with at least 64 GB
# of memory - confirm against large_gpu_mark.
large_gpu_mark(min_gb=64),
],
),
"blip2": VLMTestInfo( "blip2": VLMTestInfo(
models=["Salesforce/blip2-opt-2.7b"], models=["Salesforce/blip2-opt-2.7b"],
test_type=VLMTestType.IMAGE, test_type=VLMTestType.IMAGE,
......
...@@ -29,6 +29,8 @@ def run_test( ...@@ -29,6 +29,8 @@ def run_test(
postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
comparator: Callable[..., None], comparator: Callable[..., None],
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]], get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
stop_str: Optional[List[str]],
tokenizer_mode: str,
limit_mm_per_prompt: Dict[str, int], limit_mm_per_prompt: Dict[str, int],
model_kwargs: Optional[Dict[str, Any]], model_kwargs: Optional[Dict[str, Any]],
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
...@@ -50,11 +52,14 @@ def run_test( ...@@ -50,11 +52,14 @@ def run_test(
# vLLM needs a fresh new process without cuda initialization. # vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it # if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method). # will hurt multiprocessing backend with fork method (the default method).
vllm_kwargs = {} vllm_kwargs: Dict[str, Any] = {}
if get_stop_token_ids is not None: if get_stop_token_ids is not None:
vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer) vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
if stop_str:
vllm_kwargs["stop"] = stop_str
with vllm_runner(model, with vllm_runner(model,
tokenizer_mode=tokenizer_mode,
max_model_len=max_model_len, max_model_len=max_model_len,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
dtype=dtype, dtype=dtype,
...@@ -85,6 +90,8 @@ def run_test( ...@@ -85,6 +90,8 @@ def run_test(
hf_kwargs = {} hf_kwargs = {}
if use_tokenizer_eos: if use_tokenizer_eos:
hf_kwargs["eos_token_id"] = tokenizer.eos_token_id hf_kwargs["eos_token_id"] = tokenizer.eos_token_id
if stop_str:
hf_kwargs["stop_strings"] = stop_str
with hf_model, torch.no_grad(): with hf_model, torch.no_grad():
for prompts, media in inputs: for prompts, media in inputs:
......
...@@ -97,6 +97,9 @@ class VLMTestInfo(NamedTuple): ...@@ -97,6 +97,9 @@ class VLMTestInfo(NamedTuple):
# Optional callable which gets a list of token IDs from the model tokenizer # Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
# Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer
stop_str: Optional[List[str]] = None
# Exposed options for HF runner # Exposed options for HF runner
model_kwargs: Optional[Dict[str, Any]] = None model_kwargs: Optional[Dict[str, Any]] = None
...@@ -148,6 +151,8 @@ class VLMTestInfo(NamedTuple): ...@@ -148,6 +151,8 @@ class VLMTestInfo(NamedTuple):
marks: Optional[List[MarkDecorator]] = None marks: Optional[List[MarkDecorator]] = None
tokenizer_mode: str = "auto"
def get_non_parametrized_runner_kwargs(self): def get_non_parametrized_runner_kwargs(self):
"""Returns a dictionary of expandable kwargs for items that are used """Returns a dictionary of expandable kwargs for items that are used
in all test types, which are NOT used when creating the parametrized in all test types, which are NOT used when creating the parametrized
...@@ -166,8 +171,10 @@ class VLMTestInfo(NamedTuple): ...@@ -166,8 +171,10 @@ class VLMTestInfo(NamedTuple):
"postprocess_inputs": self.postprocess_inputs, "postprocess_inputs": self.postprocess_inputs,
"comparator": self.comparator, "comparator": self.comparator,
"get_stop_token_ids": self.get_stop_token_ids, "get_stop_token_ids": self.get_stop_token_ids,
"stop_str": self.stop_str,
"model_kwargs": self.model_kwargs, "model_kwargs": self.model_kwargs,
"patch_hf_runner": self.patch_hf_runner, "patch_hf_runner": self.patch_hf_runner,
"tokenizer_mode": self.tokenizer_mode
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment