Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom input builders for edge-cases in different models."""
from io import BytesIO
from typing import Callable
import requests
from PIL import Image
from vllm.assets.image import ImageAsset
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import (rescale_video_size, resize_video,
sample_frames_from_video)
......@@ -118,9 +115,9 @@ def different_patch_input_cases_internvl():
def windows_attention_image_qwen2_5_vl():
# image from regression issue: https://github.com/vllm-project/vllm/issues/15122
image_url = "https://aomediacodec.github.io/av1-avif/testFiles/Link-U/hato.jpg"
image = Image.open(BytesIO(requests.get(image_url).content))
# image from regression issue: https://github.com/vllm-project/vllm/issues/15122 # noqa: E501
image = ImageAsset("hato").pil_image
question = "Describe the image."
img_prompt = "<|vision_start|><|image_pad|><|vision_end|>"
......
......@@ -10,6 +10,7 @@ from typing import Optional, Union
import numpy as np
import numpy.typing as npt
import PIL.Image
import pytest
import regex as re
import torch
......@@ -19,7 +20,6 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
from transformers.video_utils import VideoMetadata
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import is_list_of
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
......@@ -343,7 +343,6 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4V."""
hf_processor = hf_model.processor
patch_padding_side(hf_processor)
def processor(*args, text="", images=None, **kwargs):
if images is None:
......@@ -812,6 +811,63 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Ovis2."""
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.llm.get_output_embeddings()
def processor(*args, text="", images=None, videos=None, **kwargs):
if images is None:
images = []
else:
images = [images] if isinstance(images, Image) else images
if videos is None:
videos = []
else:
videos = [videos] if isinstance(videos, np.ndarray) else videos
videos = [[PIL.Image.fromarray(frame) for frame in vid]
for vid in videos]
prompt_start_and_end = {
"qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
"llama":
("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
"gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
}
for start, end in prompt_start_and_end.values():
if start in text and end in text:
text = text.split(start)[1].split(end)[0]
break
images_message = [{"type": "image", "image": img} for img in images]
videos_message = [{"type": "video", "video": vid} for vid in videos]
messages = [{
"role":
"user",
"content": [
*images_message,
*videos_message,
{
"type": "text",
"text": text
},
],
}]
input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs(
messages=messages, enable_thinking=True)
inputs = {
"inputs": input_ids,
"pixel_values": pixel_values,
"grid_thws": grid_thws,
}
return BatchFeature(data=inputs, tensor_type="pt")
hf_model.processor = processor
return hf_model
def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner for Qwen2.5-Omni."""
thinker = hf_model.model.thinker
......
......@@ -15,8 +15,9 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens)
......@@ -65,6 +66,8 @@ def _test_processing_correctness(
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
......@@ -73,8 +76,7 @@ def _test_processing_correctness(
model_config,
tokenizer=cached_tokenizer_from_config(model_config),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity_gb=2048)
cache = MultiModalProcessorOnlyCache(model_config)
processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits()
......@@ -104,7 +106,7 @@ def _test_processing_correctness(
partial(random_video,
rng,
min_frames=2,
max_frames=8,
max_frames=16,
min_wh=128,
max_wh=256),
"audio":
......@@ -162,8 +164,10 @@ def _test_processing_correctness(
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"donut": False,
"mllama": False,
"ovis": False,
"ovis2_5": False,
"paligemma": False,
"ultravox": False,
"whisper": False,
......@@ -265,60 +269,72 @@ def _test_processing_correctness_one(
# yapf: disable
@pytest.mark.parametrize("model_id", [
os.path.join(models_path_prefix, "rhymes-ai/Aria"),
os.path.join(models_path_prefix, "CohereForAI/aya-vision-8b"),
os.path.join(models_path_prefix, "Salesforce/blip2-opt-2.7b"),
os.path.join(models_path_prefix, "facebook/chameleon-7b"),
os.path.join(models_path_prefix, "deepseek-ai/deepseek-vl2-tiny"),
os.path.join(models_path_prefix, "microsoft/Florence-2-base"),
os.path.join(models_path_prefix, "adept/fuyu-8b"),
os.path.join(models_path_prefix, "google/gemma-3-4b-it"),
os.path.join(models_path_prefix, "google/gemma-3n-E2B-it"),
os.path.join(models_path_prefix, "zai-org/glm-4v-9b"),
os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking"),
os.path.join(models_path_prefix, "ibm-granite/granite-speech-3.3-2b"),
os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-800m"),
os.path.join(models_path_prefix, "internlm/Intern-S1"),
os.path.join(models_path_prefix, "OpenGVLab/InternVL2-1B"),
os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B"),
os.path.join(models_path_prefix, "HuggingFaceM4/Idefics3-8B-Llama3"),
os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct"),
os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Instruct"),
os.path.join(models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"),
os.path.join(models_path_prefix, "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"),
os.path.join(models_path_prefix, "llava-hf/llava-1.5-7b-hf"),
os.path.join(models_path_prefix, "llava-hf/llava-v1.6-mistral-7b-hf"),
os.path.join(models_path_prefix, "llava-hf/LLaVA-NeXT-Video-7B-hf"),
os.path.join(models_path_prefix, "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),
os.path.join(models_path_prefix, "meta-llama/Llama-3.2-11B-Vision-Instruct"),
os.path.join(models_path_prefix, "TIGER-Lab/Mantis-8B-siglip-llama3"),
os.path.join(models_path_prefix, "openbmb/MiniCPM-Llama3-V-2_5"),
os.path.join(models_path_prefix, "openbmb/MiniCPM-o-2_6"),
os.path.join(models_path_prefix, "openbmb/MiniCPM-V-2_6"),
os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-VL-01"),
os.path.join(models_path_prefix, "allenai/Molmo-7B-D-0924"),
os.path.join(models_path_prefix, "allenai/Molmo-7B-O-0924"),
os.path.join(models_path_prefix, "nvidia/NVLM-D-72B"),
os.path.join(models_path_prefix, "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"),
os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Gemma2-9B"),
os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Llama3.2-3B"),
os.path.join(models_path_prefix, "AIDC-AI/Ovis2-1B"),
os.path.join(models_path_prefix, "google/paligemma-3b-mix-224"),
os.path.join(models_path_prefix, "google/paligemma2-3b-ft-docci-448"),
os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct"),
os.path.join(models_path_prefix, "microsoft/Phi-4-multimodal-instruct"),
os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409"),
os.path.join(models_path_prefix, "mistral-community/pixtral-12b"),
os.path.join(models_path_prefix, "Qwen/Qwen-VL-Chat"),
os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct"),
os.path.join(models_path_prefix, "Qwen/Qwen2.5-VL-3B-Instruct"),
os.path.join(models_path_prefix, "Qwen/Qwen2-Audio-7B-Instruct"),
os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-3B"),
os.path.join(models_path_prefix, "Skywork/Skywork-R1V-38B"),
os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b"),
os.path.join(models_path_prefix, "openai/whisper-large-v3"),
os.path.join(models_path_prefix, "omni-research/Tarsier-7b"),
os.path.join(models_path_prefix, "omni-research/Tarsier2-Recap-7b"),
os.path.join(models_path_prefix,"rhymes-ai/Aria"),
os.path.join(models_path_prefix,"CohereForAI/aya-vision-8b"),
os.path.join(models_path_prefix,"Salesforce/blip2-opt-2.7b"),
os.path.join(models_path_prefix,"facebook/chameleon-7b"),
os.path.join(models_path_prefix,"CohereLabs/command-a-vision-07-2025"),
os.path.join(models_path_prefix,"deepseek-ai/deepseek-vl2-tiny"),
os.path.join(models_path_prefix,"naver-clova-ix/donut-base-finetuned-docvqa"),
os.path.join(models_path_prefix,"baidu/ERNIE-4.5-VL-28B-A3B-PT"),
os.path.join(models_path_prefix,"microsoft/Florence-2-base"),
os.path.join(models_path_prefix,"adept/fuyu-8b"),
os.path.join(models_path_prefix,"google/gemma-3-4b-it"),
os.path.join(models_path_prefix,"google/gemma-3n-E2B-it"),
os.path.join(models_path_prefix,"zai-org/glm-4v-9b"),
os.path.join(models_path_prefix,"zai-org/GLM-4.1V-9B-Thinking"),
os.path.join(models_path_prefix,"zai-org/GLM-4.5V"),
os.path.join(models_path_prefix,"ibm-granite/granite-speech-3.3-2b"),
os.path.join(models_path_prefix,"h2oai/h2ovl-mississippi-800m"),
os.path.join(models_path_prefix,"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"),
os.path.join(models_path_prefix,"HuggingFaceM4/Idefics3-8B-Llama3"),
os.path.join(models_path_prefix,"internlm/Intern-S1)",
os.path.join(models_path_prefix,"OpenGVLab/InternVL2-1B"),
os.path.join(models_path_prefix,"OpenGVLab/InternVL3-1B)",
os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-1B"),
os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview"),
os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-30B-A3B"),
os.path.join(models_path_prefix,"Kwai-Keye/Keye-VL-8B-Preview"),
os.path.join(models_path_prefix,"moonshotai/Kimi-VL-A3B-Instruct"),
os.path.join(models_path_prefix,"meta-llama/Llama-4-Scout-17B-16E-Instruct"),
os.path.join(models_path_prefix,"llava-hf/llava-1.5-7b-hf"),
os.path.join(models_path_prefix,"llava-hf/llava-v1.6-mistral-7b-hf"),
os.path.join(models_path_prefix,"llava-hf/LLaVA-NeXT-Video-7B-hf"),
os.path.join(models_path_prefix,"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),
os.path.join(models_path_prefix,"meta-llama/Llama-3.2-11B-Vision-Instruct"),
os.path.join(models_path_prefix,"TIGER-Lab/Mantis-8B-siglip-llama3"),
os.path.join(models_path_prefix,"openbmb/MiniCPM-Llama3-V-2_5"),
os.path.join(models_path_prefix,"openbmb/MiniCPM-o-2_6"),
os.path.join(models_path_prefix,"openbmb/MiniCPM-V-2_6"),
os.path.join(models_path_prefix,"MiniMaxAI/MiniMax-VL-01"),
os.path.join(models_path_prefix,"allenai/Molmo-7B-D-0924"),
os.path.join(models_path_prefix,"allenai/Molmo-7B-O-0924"),
os.path.join(models_path_prefix,"nvidia/NVLM-D-72B"),
os.path.join(models_path_prefix,"nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"),
os.path.join(models_path_prefix,"AIDC-AI/Ovis1.6-Gemma2-9B"),
os.path.join(models_path_prefix,"AIDC-AI/Ovis1.6-Llama3.2-3B"),
os.path.join(models_path_prefix,"AIDC-AI/Ovis2-1B"),
os.path.join(models_path_prefix,"AIDC-AI/Ovis2.5-2B"),
os.path.join(models_path_prefix,"google/paligemma-3b-mix-224"),
os.path.join(models_path_prefix,"google/paligemma2-3b-ft-docci-448"),
os.path.join(models_path_prefix,"microsoft/Phi-3.5-vision-instruct"),
os.path.join(models_path_prefix,"microsoft/Phi-4-multimodal-instruct"),
os.path.join(models_path_prefix,"mistralai/Pixtral-12B-2409"),
os.path.join(models_path_prefix,"mistral-community/pixtral-12b"),
os.path.join(models_path_prefix,"Qwen/Qwen-VL-Chat"),
os.path.join(models_path_prefix,"Qwen/Qwen2-VL-2B-Instruct"),
os.path.join(models_path_prefix,"Qwen/Qwen2.5-VL-3B-Instruct"),
os.path.join(models_path_prefix,"Qwen/Qwen2-Audio-7B-Instruct"),
os.path.join(models_path_prefix,"Qwen/Qwen2.5-Omni-3B"),
os.path.join(models_path_prefix,"YannQi/R-4B"),
os.path.join(models_path_prefix,"Skywork/Skywork-R1V-38B"),
os.path.join(models_path_prefix,"HuggingFaceTB/SmolVLM2-2.2B-Instruct"),
os.path.join(models_path_prefix,"stepfun-ai/step3"),
os.path.join(models_path_prefix,"fixie-ai/ultravox-v0_5-llama-3_2-1b"),
os.path.join(models_path_prefix,"openai/whisper-large-v3"),
os.path.join(models_path_prefix,"omni-research/Tarsier-7b"),
os.path.join(models_path_prefix,"omni-research/Tarsier2-Recap-7b"),
os.path.join(models_path_prefix,"mistralai/Voxtral-Mini-3B-2507"),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
......@@ -372,10 +388,16 @@ def _assert_inputs_equal(
if ignore_mm_keys is None:
ignore_mm_keys = set()
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"}
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"}
assert a_rest == b_rest, msg
a_data = a["mm_kwargs"].get_data()
b_data = b["mm_kwargs"].get_data()
for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None)
b["mm_kwargs"].pop(key, None)
a_data.pop(key, None)
b_data.pop(key, None)
assert a == b, msg
assert a_data == b_data, msg
......@@ -45,7 +45,8 @@ def test_processor_override(
video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token)
video_tok_count = processed_inputs["prompt_token_ids"].count(
video_token_id)
grid_t, _, _ = processed_inputs["mm_kwargs"]["video_grid_thw"][0]
grid_t, _, _ = processed_inputs["mm_kwargs"].get_data(
)["video_grid_thw"][0]
assert grid_t == expected_grid_t
assert video_tok_count == expected_toks_per_frame * grid_t
......@@ -108,7 +108,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches
......
......@@ -70,7 +70,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches
......
......@@ -53,14 +53,14 @@ def test_processor_override(
prompt = encode_tokens(tokenizer, prompt)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
mm_kwargs = processed_inputs["mm_kwargs"]
mm_data = processed_inputs["mm_kwargs"].get_data()
# place holder replacements
prompt_token_ids = processed_inputs["prompt_token_ids"]
assert prompt_token_ids.count(config.boi_token_index) == num_imgs
assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
aspect_ratios = mm_kwargs["aspect_ratios"]
aspect_ratios = mm_data["aspect_ratios"]
num_x_separators = num_y_separators = 0
for tiles_y, tiles_x in aspect_ratios:
if tiles_x * tiles_y > 1:
......@@ -82,6 +82,6 @@ def test_processor_override(
num_patches_per_chunk = processor.info.get_patch_per_chunk(
config.vision_config)
assert prompt_token_ids.count(config.image_token_index) \
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
assert mm_kwargs["pixel_values"].shape[0] \
== mm_kwargs["patches_per_image"].sum()
\ No newline at end of file
== sum(mm_data["patches_per_image"]) * num_patches_per_chunk
assert len(mm_data["pixel_values"]) \
== sum(mm_data["patches_per_image"])
......@@ -51,18 +51,18 @@ def test_profiling(
encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
] * max_num_seqs
mm_kwargs = processor.apply(
mm_data = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
)["mm_kwargs"].get_data()
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")]
num_tiles = [[t] for t in mm_data.pop("num_tiles")]
num_tokens_per_tile = calc_token_per_chunk(image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
......
......@@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int):
hf_config = ctx.get_hf_config(Llama4Config)
mm_kwargs = processor.apply(
mm_data = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
)["mm_kwargs"].get_data()
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
chunks_per_image = prod(mm_kwargs["patches_per_image"])
chunks_per_image = prod(mm_data["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
0][1] # x-y seperator tokens
num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][
1] # x-y seperator tokens
total_tokens = total_num_patches.item() + num_tiles.item(
) + 3 # image start, image, image end
......
......@@ -70,7 +70,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape)
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches
......
......@@ -50,7 +50,8 @@ def test_processor_override(
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values"].shape
assert img_tok_count == expected_toks_per_img * num_imgs
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections.abc import Iterable
from contextlib import contextmanager
from functools import partial
from typing import Any, Union
from unittest.mock import patch
import numpy as np
import pytest
import torch.nn as nn
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.inputs import InputProcessingContext
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore
from vllm.utils import is_list_of
from ...conftest import VllmRunner
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
from ..utils import dummy_hf_overrides
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
from ...utils import dummy_hf_overrides
ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
"MiniMaxVL01ForConditionalGeneration": "broken model",
}
ARCH_NEEDS_EXTRAS = [
"InternVLChatModel",
......@@ -39,7 +39,12 @@ ARCH_NEEDS_EXTRAS = [
"MiniCPMV",
"PaliGemmaForConditionalGeneration",
]
REPO_ID_TO_SKIP = {"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test"}
REPO_ID_TO_SKIP = {
"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test",
# FIXME(Isotr0py): enable GPT-OSS based InternVL3.5 model
# after support PP for GPT-OSS
"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview": "Broken model",
}
ImageInput = list[Image.Image]
VideoInput = Union[list[Image.Image], list[np.ndarray],
......@@ -128,11 +133,32 @@ def create_batched_mm_kwargs(
)["mm_kwargs"]
items = [
item for modality in supported_mm_limits
for item in mm_kwargs.get_items(modality)
for item in mm_kwargs[modality]
]
return group_mm_kwargs_by_modality(items)
@contextmanager
def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=1)
vllm_config = VllmConfig(model_config=model_config)
with set_current_vllm_config(vllm_config=vllm_config):
with set_default_torch_dtype(model_config.dtype):
model = model_cls(vllm_config=vllm_config)
yield model
del model
cleanup_dist_env_and_memory()
def get_model_id_to_test(
model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
filtered_results = []
......@@ -148,12 +174,10 @@ def get_model_id_to_test(
return filtered_results
@pytest.mark.core_model
@pytest.mark.parametrize(
"model_arch, model_id",
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, model_id: str,
vllm_runner: type[VllmRunner], monkeypatch):
def test_model_tensor_schema(model_arch: str, model_id: str):
if model_arch in ARCH_TO_SKIP:
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
if model_id in REPO_ID_TO_SKIP:
......@@ -174,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
hf_overrides=hf_overrides_fn,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
if not any(
hasattr(model_cls, f"_parse_and_validate_{m}_input")
for m in ["image", "video", "audio"]):
inputs_parse_methods = []
for attr_name in dir(model_cls):
attr = getattr(model_cls, attr_name)
if hasattr(attr, "__annotations__"):
return_type = attr.__annotations__.get("return", None)
if return_type is not None and "Input" in str(return_type):
inputs_parse_methods.append(attr_name)
if not any(inputs_parse_methods):
pytest.skip(f"{model_arch} does not support tensor schema validation.")
ctx = InputProcessingContext(
......@@ -194,68 +224,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}
# Avoid calling model.forward()
def _initialize_kv_caches_v0(self) -> None:
self.cache_config.num_gpu_blocks = 0
self.cache_config.num_cpu_blocks = 0
def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
scheduler_kv_cache_config = get_kv_cache_config(
vllm_config,
kv_cache_specs[0],
10 * GiB_bytes,
)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
_initialize_kv_caches_v0),
patch.object(V1EngineCore, "_initialize_kv_caches",
_initialize_kv_caches_v1), monkeypatch.context() as m):
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
# TODO(Isotr0py): Can we avoid initializing engine?
with (
set_default_torch_num_threads(1),
vllm_runner(
model_id,
tokenizer_name=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
max_model_len=model_info.max_model_len,
load_format="dummy",
hf_overrides=hf_overrides_fn,
limit_mm_per_prompt=limit_mm_per_prompt,
enforce_eager=True,
) as vllm_model,
):
model_config = vllm_model.llm.llm_engine.model_config
llm_engine = vllm_model.llm.llm_engine
if hasattr(llm_engine, "processor"):
# v1 processor
mm_registry = llm_engine.processor.mm_registry
else:
# v0 input_preprocessor
mm_registry = llm_engine.input_preprocessor.mm_registry
processor = mm_registry.create_processor(model_config)
def validate_model_input(model, modality: str,
mm_kwargs: MultiModalKwargs):
method_name = f"_parse_and_validate_{modality}_input"
if hasattr(model, method_name):
getattr(model, method_name)(**mm_kwargs)
for modality, _, mm_kwargs in create_batched_mm_kwargs(
model_config, processor):
valid_func = partial(validate_model_input,
modality=modality,
mm_kwargs=mm_kwargs)
vllm_model.apply_model(valid_func)
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
processor = factories.build_processor(ctx, cache=None)
with initialize_dummy_model(model_cls, model_config) as model:
for modality, _, mm_kwargs in create_batched_mm_kwargs(
model_config, processor):
for method_name in inputs_parse_methods:
print(f"Testing `{method_name}` with modality={modality} "
f"and mm_kwargs{list(mm_kwargs.keys())}")
getattr(model, method_name)(modality=modality, **mm_kwargs)
......@@ -33,7 +33,7 @@ from ...utils import models_path_prefix
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
......@@ -58,9 +58,6 @@ def test_models(
numerical sensitive kernels.
"""
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
pytest.skip(
f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
......
......@@ -141,7 +141,10 @@ class _HfExamplesInfo:
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"AquilaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/AquilaChat-7B"),
"ApertusForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"swiss-ai/Apertus-8B"),
min_transformers_version="4.56.0",
trust_remote_code=True),
"AquilaModel": _HfExamplesInfo(os.path.join(models_path_prefix,"BAAI/AquilaChat-7B"),
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/AquilaChat2-7B"),
trust_remote_code=True),
......@@ -219,10 +222,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"HunYuanDenseV1ForCausalLM":_HfExamplesInfo(os.path.join(models_path_prefix, "tencent/Hunyuan-7B-Instruct-0124"),
trust_remote_code=True,
is_available_online=False),
"HCXVisionForCausalLM": _HfExamplesInfo(
os.path.join(models_path_prefix, "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"),
trust_remote_code=True),
"InternLMForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "internlm/internlm-chat-7b"),
"InternLMForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"internlm/internlm-chat-7b"),
trust_remote_code=True),
"InternLM2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "internlm/internlm2-chat-7b"),
trust_remote_code=True),
......@@ -237,11 +237,13 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"tiny": os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-dev"),
"random": os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-random"), # noqa: E501
}),
"LlamaForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
extras={"guard": os.path.join(models_path_prefix, "meta-llama/Llama-Guard-3-1B"), # noqa: E501
"hermes": os.path.join(models_path_prefix, "NousResearch/Hermes-3-Llama-3.1-8B"), # noqa: E501
"fp8": os.path.join(models_path_prefix, "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8")}), # noqa: E501
"LLaMAForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "decapoda-research/llama-7b-hf"),
"Lfm2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"LiquidAI/LFM2-1.2B"),
min_transformers_version="4.54"),
"LlamaForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"meta-llama/Llama-3.2-1B-Instruct"),
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501
"LLaMAForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"decapoda-research/llama-7b-hf"),
is_available_online=False),
"Llama4ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
is_available_online=False),
......@@ -293,18 +295,20 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Qwen/Qwen2-0.5B-Instruct"),
extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501
"Qwen2MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen1.5-MoE-A2.7B-Chat")),
"Qwen3ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen3-8B")),
"Qwen3MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen3-30B-A3B")),
"RWForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "tiiuae/falcon-40b")),
"SmolLM3ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "HuggingFaceTB/SmolLM3-3B")),
"StableLMEpochForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "stabilityai/stablelm-zephyr-3b")), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "stabilityai/stablelm-3b-4e1t")),
"Starcoder2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "bigcode/starcoder2-3b")),
"Step3TextForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "stepfun-ai/step3"),
trust_remote_code=True,
is_available_online=False),
"SolarForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "upstage/solar-pro-preview-instruct"),
"Qwen2MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Qwen/Qwen1.5-MoE-A2.7B-Chat")),
"Qwen3ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Qwen/Qwen3-8B")),
"Qwen3MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Qwen/Qwen3-30B-A3B")),
"RWForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"tiiuae/falcon-40b")),
"SeedOssForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"ByteDance-Seed/Seed-OSS-36B-Instruct"), # noqa: E501
trust_remote_code=True,
is_available_online=False),
"SmolLM3ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"HuggingFaceTB/SmolLM3-3B")),
"StableLMEpochForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"stabilityai/stablelm-zephyr-3b")), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"stabilityai/stablelm-3b-4e1t")),
"Starcoder2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"bigcode/starcoder2-3b")),
"Step3TextForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"stepfun-ai/step3"),
trust_remote_code=True),
"SolarForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"upstage/solar-pro-preview-instruct"),
trust_remote_code=True),
"TeleChat2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Tele-AI/TeleChat2-3B"),
trust_remote_code=True),
......@@ -330,8 +334,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5"), v0_only=True),
"Gemma2Model": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-multilingual-gemma2"), v0_only=True), # noqa: E501
"BertModel": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5")),
"Gemma2Model": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-multilingual-gemma2")), # noqa: E501
"GritLM": _HfExamplesInfo(os.path.join(models_path_prefix, "parasail-ai/GritLM-7B-vllm")),
"GteModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Snowflake/snowflake-arctic-embed-m-v2.0"),
trust_remote_code=True),
......@@ -344,9 +348,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"LlamaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "llama"), is_available_online=False),
"MistralModel": _HfExamplesInfo(os.path.join(models_path_prefix, "intfloat/e5-mistral-7b-instruct")),
"ModernBertModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Alibaba-NLP/gte-modernbert-base"),
trust_remote_code=True, v0_only=True),
trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo(os.path.join(models_path_prefix, "nomic-ai/nomic-embed-text-v2-moe"),
trust_remote_code=True, v0_only=True), # noqa: E501
trust_remote_code=True), # noqa: E501
"Qwen2Model": _HfExamplesInfo(os.path.join(models_path_prefix, "ssmits/Qwen2-7B-Instruct-embed-base")),
"Qwen2ForRewardModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"),
max_transformers_version="4.53",
......@@ -354,9 +358,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Qwen2ForProcessRewardModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-PRM-7B"),
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
"RobertaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "sentence-transformers/stsb-roberta-base-v2"), v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo(os.path.join(models_path_prefix, "sentence-transformers/all-roberta-large-v1"), v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "intfloat/multilingual-e5-small"), v0_only=True), # noqa: E501
"RobertaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "sentence-transformers/stsb-roberta-base-v2")), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo(os.path.join(models_path_prefix, "sentence-transformers/all-roberta-large-v1")), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "intfloat/multilingual-e5-small")), # noqa: E501
# [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "royokong/e5-v")),
"Phi3VForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "TIGER-Lab/VLM2Vec-Full"),
......@@ -371,16 +375,19 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"GPT2ForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "nie3e/sentiment-polish-gpt2-small")), # noqa: E501
# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "cross-encoder/ms-marco-MiniLM-L-6-v2"), v0_only=True), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "Alibaba-NLP/gte-reranker-modernbert-base"), v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "cross-encoder/quora-roberta-base"), v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-m3"), v0_only=True), # noqa: E501
"BertForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "cross-encoder/ms-marco-MiniLM-L-6-v2")), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "Alibaba-NLP/gte-multilingual-reranker-base"), # noqa: E501
trust_remote_code=True,
hf_overrides={
"architectures": [os.path.join(models_path_prefix, "GteNewForSequenceClassification")]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "Alibaba-NLP/gte-reranker-modernbert-base")), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "cross-encoder/quora-roberta-base")), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-m3")), # noqa: E501
}
_AUTOMATIC_CONVERTED_MODELS = {
# Use as_seq_cls_model for automatic conversion
"GemmaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-gemma"), # noqa: E501
v0_only=True,
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
......@@ -403,6 +410,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"Emu3ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/Emu3-Chat-hf")),
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "baidu/ERNIE-4.5-VL-28B-A3B-PT"), # noqa: E501
trust_remote_code=True),
"FuyuForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "adept/fuyu-8b")),
"Gemma3ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "google/gemma-3-4b-it")),
"Gemma3nForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "google/gemma-3n-E2B-it"), # noqa: E501
......@@ -413,23 +422,29 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking")), # noqa: E501
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "zai-org/GLM-4.5V"),
is_available_online=False), # noqa: E501
min_transformers_version="4.56"), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo(os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-800m"),
trust_remote_code=True,
extras={"2b": os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-2b")}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
"Idefics3ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501
{"tiny": os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM-256M-Instruct")}, # noqa: E501
min_transformers_version="4.55.1",
transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501
"InternVLChatModel": _HfExamplesInfo(os.path.join(models_path_prefix, "OpenGVLab/InternVL2-1B"),
extras={"2B": os.path.join(models_path_prefix, "OpenGVLab/InternVL2-2B"),
"3.0": os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B")}, # noqa: E501
trust_remote_code=True),
"InternS1ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "internlm/Intern-S1"),
"HCXVisionForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"), # noqa: E501
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501
{"tiny": os.path.join(models_path_prefix,"HuggingFaceTB/SmolVLM-256M-Instruct")}, # noqa: E501
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55"), # noqa: E501
"InternS1ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"internlm/Intern-S1"),
trust_remote_code=True), # noqa: E501
"InternVLChatModel": _HfExamplesInfo(os.path.join(models_path_prefix,"OpenGVLab/InternVL2-1B"),
extras={"2B": os.path.join(models_path_prefix,"OpenGVLab/InternVL2-2B"),
"3.0": os.path.join(models_path_prefix,"OpenGVLab/InternVL3-1B"), # noqa: E501
"3.5-qwen3": os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-1B"), # noqa: E501
"3.5-qwen3moe": os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-30B-A3B"), # noqa: E501
"3.5-gptoss": os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview")}, # noqa: E501
trust_remote_code=True),
"KeyeForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "Kwai-Keye/Keye-VL-8B-Preview"), # noqa: E501
"InternVLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"OpenGVLab/InternVL3-1B-hf")), # noqa: E501
"KeyeForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"Kwai-Keye/Keye-VL-8B-Preview"), # noqa: E501
trust_remote_code=True),
"KimiVLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Instruct"), # noqa: E501
extras={"thinking": os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Thinking")}, # noqa: E501
......@@ -451,7 +466,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMO": _HfExamplesInfo(os.path.join(models_path_prefix, "openbmb/MiniCPM-o-2_6"),
trust_remote_code=True),
"MiniCPMV": _HfExamplesInfo(os.path.join(models_path_prefix, "openbmb/MiniCPM-Llama3-V-2_5"),
extras={"2.6": os.path.join(models_path_prefix, "openbmb/MiniCPM-V-2_6"), "4.0": os.path.join(models_path_prefix, "openbmb/MiniCPM-V-4")}, # noqa: E501
extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501
trust_remote_code=True),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-VL-01"), # noqa: E501
trust_remote_code=True,
......@@ -473,6 +488,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
transformers_version_reason="HF model is not compatible", # noqa: E501
extras={"1.6-llama": os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Llama3.2-3B"),
"1.6-gemma": os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Gemma2-9B")}), # noqa: E501
"Ovis2_5": _HfExamplesInfo(os.path.join(models_path_prefix, "AIDC-AI/Ovis2.5-2B"),
trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "google/paligemma-3b-mix-224"), # noqa: E501
extras={"v2": os.path.join(models_path_prefix, "google/paligemma2-3b-ft-docci-448")}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "microsoft/Phi-3-vision-128k-instruct"),
......@@ -496,26 +513,30 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len=4096),
"Qwen2_5OmniModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-3B")),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-7B-AWQ")), # noqa: E501
"RForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "YannQi/R-4B"),
trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Skywork/Skywork-R1V-38B"),
trust_remote_code=True),
"SmolVLMForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
min_transformers_version="4.55.1",
transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55"), # noqa: E501
"Step3VLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "stepfun-ai/step3"),
trust_remote_code=True,
is_available_online=False),
trust_remote_code=True),
"UltravoxModel": _HfExamplesInfo(os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b"), # noqa: E501
trust_remote_code=True),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501
"Tarsier2ForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier2-Recap-7b", # noqa: E501
"TarsierForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "omni-research/Tarsier-7b")), # noqa: E501
"Tarsier2ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "omni-research/Tarsier2-Recap-7b"), # noqa: E501
hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Voxtral-Mini-3B-2507",
os.path.join(models_path_prefix, "mistralai/Voxtral-Mini-3B-2507"),
min_transformers_version="4.54",
# disable this temporarily until we support HF format
is_available_online=False,
),
# [Encoder-decoder]
"DonutForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "naver-clova-ix/donut-base-finetuned-docvqa"), # noqa: E501
hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501
extras={"dolphin": os.path.join(models_path_prefix, "ByteDance/Dolphin")}), # noqa: E501
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"microsoft/Florence-2-base"), # noqa: E501
......@@ -538,6 +559,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"DeepSeekMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "luccafong/deepseek_mtp_main_random"),
speculative_model=os.path.join(models_path_prefix, "luccafong/deepseek_mtp_draft_random"), # noqa: E501
trust_remote_code=True),
"EagleDeepSeekMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "eagle618/deepseek-v3-random"),
speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501
trust_remote_code=True),
"EagleLlamaForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "yuhuili/EAGLE-LLaMA3-Instruct-8B"),
trust_remote_code=True,
speculative_model=os.path.join(models_path_prefix, "yuhuili/EAGLE-LLaMA3-Instruct-8B"),
......@@ -561,6 +585,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
is_available_online=False,
speculative_model=os.path.join(models_path_prefix, "openbmb/MiniCPM-2B-sft-bf16"),
tokenizer=os.path.join(models_path_prefix, "openbmb/MiniCPM-2B-sft-bf16")),
"ErnieMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT"),
trust_remote_code=True,
speculative_model=os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT")),
"Glm4MoeMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "zai-org/GLM-4.5"),
speculative_model=os.path.join(models_path_prefix, "zai-org/GLM-4.5"),
min_transformers_version="4.54",
......@@ -573,7 +600,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen3-Embedding-0.6B")),
"TransformersForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "hmellor/Ilama-3.2-1B"), trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B-hf")),
"TransformersForMultimodalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/Emu3-Chat-hf")),
}
_EXAMPLE_MODELS = {
......
......@@ -38,11 +38,6 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
model_arch=model_arch,
exist_overrides=model_info.hf_overrides)
if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
# Avoid calling model.forward()
def _initialize_kv_caches_v0(self) -> None:
self.cache_config.num_gpu_blocks = 0
......@@ -95,6 +90,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
if model_arch == "Lfm2ForCausalLM":
pytest.skip("Skipping until test supports V1-only models")
can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)
......
......@@ -27,6 +27,9 @@ models_path_prefix = os.getenv('VLLM_OPTEST_MODELS_PATH') or os.getenv("OPTEST_M
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
# Skip if transformers version is incompatible
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(on_fail="skip")
# Ensure all model classes can be imported successfully
model_cls = ModelRegistry._try_load_model_cls(model_arch)
assert model_cls is not None
......
......@@ -3,7 +3,8 @@
import warnings
from collections.abc import Sequence
from typing import Any, NamedTuple, Optional, Union
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.nn.functional as F
......@@ -339,36 +340,43 @@ def softmax(data):
return F.softmax(data, dim=-1)
class EmbedModelInfo(NamedTuple):
@dataclass
class ModelInfo:
name: str
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
dtype: str = "auto"
hf_overrides: Optional[dict[str, Any]] = None
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class EmbedModelInfo(ModelInfo):
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
@dataclass
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "LAST"
class RerankModelInfo(NamedTuple):
name: str
architecture: str = ""
dtype: str = "auto"
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class RerankModelInfo(ModelInfo):
pass
@dataclass
class CLSPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "LAST"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal.cache import (MultiModalCache,
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
processor_cache_from_config,
receiver_cache_from_config)
from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField)
from vllm.multimodal.processing import PromptInsertion
from vllm.multimodal.registry import MultiModalRegistry
def _dummy_elem(
modality: str,
key: str,
size: int,
*,
rng: Optional[np.random.RandomState] = None,
):
if rng is None:
data = torch.empty((size, ), dtype=torch.int8)
else:
data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8))
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
data=data,
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
def _dummy_item(
modality: str,
size_by_key: dict[str, int],
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
_dummy_elem(modality, key, size, rng=rng)
for key, size in size_by_key.items()
])
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs([
_dummy_item(modality, size_by_key)
def _dummy_items(
size_by_key_modality: dict[str, dict[str, int]],
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItems.from_seq([
_dummy_item(modality, size_by_key, rng=rng)
for modality, size_by_key in size_by_key_modality.items()
])
......@@ -37,7 +69,8 @@ def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501
],
)
# yapf: enable
......@@ -47,5 +80,139 @@ def test_cache_item_size(item, expected_size):
cache[""] = item
assert cache.currsize == expected_size
cache[""] = MultiModalCacheItemMetadata.wraps(item)
prompt_update = PromptInsertion("dummy", "target", "insertion") \
.resolve(0)
cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
assert cache.currsize == expected_size
cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
assert cache.currsize == expected_size
def _create_vllm_config(
*,
mm_processor_cache_gb: float,
enable_ipc: bool,
):
return VllmConfig(
model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb),
parallel_config=ParallelConfig(
data_parallel_size=1 if enable_ipc else 2),
)
def _compare_caches(
config_0: VllmConfig,
config_1: VllmConfig,
*,
item_capacity: int = 8,
hit_rate: float = 0.5,
max_items_per_iter: int = 3,
is_cached_calls_per_iter: int,
n_iter: int = 100,
seed: int = 0,
):
mm_registry = MultiModalRegistry()
cache_0_p0 = processor_cache_from_config(config_0, mm_registry)
cache_0_p1 = receiver_cache_from_config(config_0, mm_registry)
cache_1_p0 = processor_cache_from_config(config_1, mm_registry)
cache_1_p1 = receiver_cache_from_config(config_1, mm_registry)
cache_size_gb = max(
config_0.model_config.mm_processor_cache_gb,
config_1.model_config.mm_processor_cache_gb,
)
item_size_gb = int(cache_size_gb / item_capacity)
rng = np.random.RandomState(seed)
all_items = [
_dummy_item("item", {"key": item_size_gb}, rng=rng)
for _ in range(int(item_capacity / hit_rate))
]
all_hashes = [
MultiModalHasher.hash_kwargs(item=item.get_data())
for item in all_items
]
# Should not be used since there is nothing to convert to text
prompt_update = PromptInsertion("dummy", "target", "insertion")
for it in range(n_iter):
num_items_to_select = rng.randint(0, max_items_per_iter)
item_idxs_to_select = rng.choice(len(all_items), num_items_to_select)
selected_items = [all_items[idx] for idx in item_idxs_to_select]
selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select]
if cache_0_p0 is None:
cache_0_p0_out = selected_items
else:
for _ in range(is_cached_calls_per_iter):
cache_0_p0.is_cached(selected_hashes)
cache_0_p0_out = [
item for item, _ in cache_0_p0.get_and_update(
[(item, prompt_update.content) for item in selected_items],
selected_hashes,
)
]
if cache_1_p0 is None:
cache_1_p0_out = selected_items
else:
for _ in range(is_cached_calls_per_iter):
cache_1_p0.is_cached(selected_hashes)
cache_1_p0_out = [
item for item, _ in cache_1_p0.get_and_update(
[(item, prompt_update.content) for item in selected_items],
selected_hashes,
)
]
if cache_0_p1 is None:
cache_0_p1_out = cache_0_p0_out
else:
cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out,
selected_hashes)
if cache_1_p1 is None:
cache_1_p1_out = cache_1_p0_out
else:
cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out,
selected_hashes)
assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}"
@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3])
def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
cache_size_gb = 1 / (1 << 20)
vllm_config_ipc_enabled = _create_vllm_config(
mm_processor_cache_gb=cache_size_gb,
enable_ipc=True,
)
vllm_config_ipc_disabled = _create_vllm_config(
mm_processor_cache_gb=0,
enable_ipc=False,
)
vllm_config_cache_disabled = _create_vllm_config(
mm_processor_cache_gb=cache_size_gb,
enable_ipc=True,
)
_compare_caches(
vllm_config_ipc_enabled,
vllm_config_ipc_disabled,
is_cached_calls_per_iter=is_cached_calls_per_iter,
)
_compare_caches(
vllm_config_ipc_disabled,
vllm_config_cache_disabled,
is_cached_calls_per_iter=is_cached_calls_per_iter,
)
_compare_caches(
vllm_config_cache_disabled,
vllm_config_ipc_enabled,
is_cached_calls_per_iter=is_cached_calls_per_iter,
)
......@@ -45,10 +45,11 @@ def test_hash_collision_image_transpose():
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
def test_hash_collision_tensor_shape():
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_hash_collision_tensor_shape(dtype):
# The hash should be different though the data is the same when flattened
arr1 = torch.zeros((5, 10, 20, 3))
arr2 = torch.zeros((10, 20, 5, 3))
arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype)
arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype)
hasher = MultiModalHasher
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
......
......@@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement, apply_text_matches,
apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
from .utils import random_image
......@@ -75,12 +73,15 @@ from .utils import random_image
),
],
)
@pytest.mark.parametrize("start_idx", [0, 4, 8])
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
result = list(iter_token_matches(token_ids, match_ids))
def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
result = list(iter_token_matches(token_ids, match_ids,
start_idx=start_idx))
# Manually constructed results
assert [item._asdict() for item in result] == expected
assert [item._asdict() for item in result
] == [item for item in expected if item["start_idx"] >= start_idx]
# Invariants
match_lens = [end - start for start, end in result]
......@@ -241,21 +242,23 @@ def test_find_token_matches(
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
prompt_updates = {
key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_updates)
}
result = {
key: list(update.iter_token_matches(prompt, mock_tokenizer))
for key, update in prompt_updates.items()
}
# Only displayed on error
print("result:", result)
# Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert {
key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, [])
for item in result.get(key, [])
]
for key in expected_by_key
} == expected_by_key
......@@ -388,21 +391,23 @@ def test_find_text_matches(
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
prompt_updates = {
key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_updates)
}
result = {
key: list(update.iter_text_matches(prompt, mock_tokenizer))
for key, update in prompt_updates.items()
}
# Only displayed on error
print("result:", result)
# Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert {
key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, [])
for item in result.get(key, [])
]
for key in expected_by_key
} == expected_by_key
......@@ -552,39 +557,35 @@ def test_find_update_text(
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
for mm_count, expected in expected_by_mm_count.items():
result = apply_text_matches(
mm_prompt_updates = {
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)]
for key, target in target_by_key.items()
}
new_prompt, result = apply_text_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
mm_prompt_updates,
mock_tokenizer,
)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("mm_prompt_updates:", mm_prompt_updates)
print("new_prompt:", new_prompt)
print("result:", result)
# Manually constructed results
assert result == expected
assert new_prompt == expected
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
# Tokenized test cases of `test_find_replace_text`
# Tokenized test cases of `test_find_update_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
(
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
......@@ -726,32 +727,28 @@ def test_find_update_tokens(
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
for mm_count, expected in expected_by_mm_count.items():
result = apply_token_matches(
mm_prompt_updates = {
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)]
for key, target in target_by_key.items()
}
new_prompt, result = apply_token_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
mm_prompt_updates,
mock_tokenizer,
)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("mm_prompt_updates:", mm_prompt_updates)
print("new_prompt:", new_prompt)
print("result:", result)
# Manually constructed results
assert result == expected
assert new_prompt == expected
# yapf: disable
......@@ -878,17 +875,11 @@ def test_find_mm_placeholders(
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_updates = {
key: [update_type(key, [], repl).bind(mock_tokenizer)]
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
for key, repl in repl_by_key.items()
}
result = find_mm_placeholders(
mm_prompt_updates,
prompt,
# Effectively match all occurrences in the prompt
{key: 3
for key in repl_by_key},
)
result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
# Only displayed on error
print("result:", result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment