Commit 27e8d1ea (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 5c79b0d6
......@@ -77,6 +77,7 @@ Internal data structures.
- [vllm.multimodal.inputs.MultiModalFieldElem][]
- [vllm.multimodal.inputs.MultiModalFieldConfig][]
- [vllm.multimodal.inputs.MultiModalKwargsItem][]
- [vllm.multimodal.inputs.MultiModalKwargsItems][]
- [vllm.multimodal.inputs.MultiModalKwargs][]
- [vllm.multimodal.inputs.MultiModalInputs][]
......
......@@ -629,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
......@@ -778,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
......
......@@ -370,10 +370,16 @@ def _assert_inputs_equal(
if ignore_mm_keys is None:
ignore_mm_keys = set()
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"}
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"}
assert a_rest == b_rest, msg
a_data = a["mm_kwargs"].get_data()
b_data = b["mm_kwargs"].get_data()
for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None)
b["mm_kwargs"].pop(key, None)
a_data.pop(key, None)
b_data.pop(key, None)
assert a == b, msg
assert a_data == b_data, msg
......@@ -45,7 +45,8 @@ def test_processor_override(
video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token)
video_tok_count = processed_inputs["prompt_token_ids"].count(
video_token_id)
grid_t, _, _ = processed_inputs["mm_kwargs"]["video_grid_thw"][0]
grid_t, _, _ = processed_inputs["mm_kwargs"].get_data(
)["video_grid_thw"][0]
assert grid_t == expected_grid_t
assert video_tok_count == expected_toks_per_frame * grid_t
......@@ -108,7 +108,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches
......
......@@ -68,7 +68,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches
......
......@@ -51,14 +51,14 @@ def test_processor_override(
prompt = encode_tokens(tokenizer, prompt)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
mm_kwargs = processed_inputs["mm_kwargs"]
mm_data = processed_inputs["mm_kwargs"].get_data()
# place holder replacements
prompt_token_ids = processed_inputs["prompt_token_ids"]
assert prompt_token_ids.count(config.boi_token_index) == num_imgs
assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
aspect_ratios = mm_kwargs["aspect_ratios"]
aspect_ratios = mm_data["aspect_ratios"]
num_x_separators = num_y_separators = 0
for tiles_y, tiles_x in aspect_ratios:
if tiles_x * tiles_y > 1:
......@@ -80,6 +80,6 @@ def test_processor_override(
num_patches_per_chunk = processor.info.get_patch_per_chunk(
config.vision_config)
assert prompt_token_ids.count(config.image_token_index) \
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
assert mm_kwargs["pixel_values"].shape[0] \
== mm_kwargs["patches_per_image"].sum()
== sum(mm_data["patches_per_image"]) * num_patches_per_chunk
assert len(mm_data["pixel_values"]) \
== sum(mm_data["patches_per_image"])
......@@ -49,18 +49,18 @@ def test_profiling(
encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
] * max_num_seqs
mm_kwargs = processor.apply(
mm_data = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
)["mm_kwargs"].get_data()
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")]
num_tiles = [[t] for t in mm_data.pop("num_tiles")]
num_tokens_per_tile = calc_token_per_chunk(image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
......
......@@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int):
hf_config = ctx.get_hf_config(Llama4Config)
mm_kwargs = processor.apply(
mm_data = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
)["mm_kwargs"].get_data()
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
chunks_per_image = prod(mm_kwargs["patches_per_image"])
chunks_per_image = prod(mm_data["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
0][1] # x-y separator tokens
num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][
1] # x-y separator tokens
total_tokens = total_num_patches.item() + num_tiles.item(
) + 3 # image start, image, image end
......
......@@ -70,7 +70,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape)
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches
......
......@@ -48,7 +48,8 @@ def test_processor_override(
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values"].shape
assert img_tok_count == expected_toks_per_img * num_imgs
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
......
......@@ -128,7 +128,7 @@ def create_batched_mm_kwargs(
)["mm_kwargs"]
items = [
item for modality in supported_mm_limits
for item in mm_kwargs.get_items(modality)
for item in mm_kwargs[modality]
]
return group_mm_kwargs_by_modality(items)
......
......@@ -4,8 +4,8 @@ import pytest
import torch
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField)
......@@ -24,8 +24,8 @@ def _dummy_item(modality: str, size_by_key: dict[str, int]):
])
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs([
def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargsItems.from_seq([
_dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items()
])
......@@ -37,7 +37,8 @@ def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501
],
)
# yapf: enable
......
......@@ -11,7 +11,8 @@ import torch
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalFlatField,
MultiModalKwargs, MultiModalKwargsItem,
MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
......@@ -96,7 +97,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
class MyRequest(msgspec.Struct):
mm: Optional[list[MultiModalKwargs]]
mm: Optional[list[MultiModalKwargsItems]]
def test_multimodal_kwargs():
......@@ -119,7 +120,7 @@ def test_multimodal_kwargs():
audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2])
image = MultiModalKwargsItem.from_elems([e3, e4])
mm = MultiModalKwargs([audio, video, image])
mm = MultiModalKwargsItems.from_seq([audio, video, image])
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm])
......@@ -133,19 +134,22 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14255, +-20 for minor changes
assert 14250 <= total_len <= 14300
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
# expected total encoding length, should be 14306, +-20 for minor changes
assert 14275 <= total_len <= 14325
decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems)
# check all modalities were recovered and do some basic sanity checks
assert len(decoded.modalities) == 3
images = decoded.get_items("image")
assert len(decoded) == 3
images = decoded["image"]
assert len(images) == 1
assert len(images[0].items()) == 2
assert list(images[0].keys()) == ["i0", "i1"]
# check the tensor contents and layout in the main dict
assert all(nested_equal(mm[k], decoded[k]) for k in mm)
mm_data = mm.get_data()
decoded_data = decoded.get_data()
assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data)
def nested_equal(a: NestedTensors, b: NestedTensors):
......
......@@ -4,11 +4,12 @@
from array import array
from typing import Any, Type
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types.
"""Custom msgspec enc hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
......@@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any:
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.")
return obj.tobytes()
if isinstance(obj, MultiModalKwargs):
return dict(obj)
def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types.
"""Custom msgspec dec hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
......@@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj)
return deserialized
if type is MultiModalKwargs:
return MultiModalKwargs(obj)
......@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
......@@ -470,7 +470,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
......
......@@ -18,7 +18,7 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
......@@ -242,7 +242,7 @@ class AyaVisionMultiModalProcessor(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token
......
......@@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
......@@ -492,7 +492,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
......
......@@ -31,7 +31,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
......@@ -151,7 +151,7 @@ class ChameleonMultiModalProcessor(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
......
......@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
......@@ -241,7 +241,7 @@ class Cohere2VisionMultiModalProcessor(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment