Unverified commit 27e8d1ea, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 5c79b0d6
...@@ -77,6 +77,7 @@ Internal data structures. ...@@ -77,6 +77,7 @@ Internal data structures.
- [vllm.multimodal.inputs.MultiModalFieldElem][] - [vllm.multimodal.inputs.MultiModalFieldElem][]
- [vllm.multimodal.inputs.MultiModalFieldConfig][] - [vllm.multimodal.inputs.MultiModalFieldConfig][]
- [vllm.multimodal.inputs.MultiModalKwargsItem][] - [vllm.multimodal.inputs.MultiModalKwargsItem][]
- [vllm.multimodal.inputs.MultiModalKwargsItems][]
- [vllm.multimodal.inputs.MultiModalKwargs][] - [vllm.multimodal.inputs.MultiModalKwargs][]
- [vllm.multimodal.inputs.MultiModalInputs][] - [vllm.multimodal.inputs.MultiModalInputs][]
......
...@@ -629,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies ...@@ -629,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
...@@ -778,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies ...@@ -778,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id bos_token_id = hf_config.bos_token_id
......
...@@ -370,10 +370,16 @@ def _assert_inputs_equal( ...@@ -370,10 +370,16 @@ def _assert_inputs_equal(
if ignore_mm_keys is None: if ignore_mm_keys is None:
ignore_mm_keys = set() ignore_mm_keys = set()
assert "mm_kwargs" in a and "mm_kwargs" in b, msg a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"}
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"}
assert a_rest == b_rest, msg
a_data = a["mm_kwargs"].get_data()
b_data = b["mm_kwargs"].get_data()
for key in ignore_mm_keys: for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None) a_data.pop(key, None)
b["mm_kwargs"].pop(key, None) b_data.pop(key, None)
assert a == b, msg assert a_data == b_data, msg
...@@ -45,7 +45,8 @@ def test_processor_override( ...@@ -45,7 +45,8 @@ def test_processor_override(
video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token) video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token)
video_tok_count = processed_inputs["prompt_token_ids"].count( video_tok_count = processed_inputs["prompt_token_ids"].count(
video_token_id) video_token_id)
grid_t, _, _ = processed_inputs["mm_kwargs"]["video_grid_thw"][0] grid_t, _, _ = processed_inputs["mm_kwargs"].get_data(
)["video_grid_thw"][0]
assert grid_t == expected_grid_t assert grid_t == expected_grid_t
assert video_tok_count == expected_toks_per_frame * grid_t assert video_tok_count == expected_toks_per_frame * grid_t
...@@ -108,7 +108,8 @@ def _run_check( ...@@ -108,7 +108,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
assert img_tok_count == 256 * total_expected_num_patches assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches
......
...@@ -68,7 +68,8 @@ def _run_check( ...@@ -68,7 +68,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
assert img_tok_count == 256 * total_expected_num_patches assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches
......
...@@ -51,14 +51,14 @@ def test_processor_override( ...@@ -51,14 +51,14 @@ def test_processor_override(
prompt = encode_tokens(tokenizer, prompt) prompt = encode_tokens(tokenizer, prompt)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
mm_kwargs = processed_inputs["mm_kwargs"] mm_data = processed_inputs["mm_kwargs"].get_data()
# place holder replacements # place holder replacements
prompt_token_ids = processed_inputs["prompt_token_ids"] prompt_token_ids = processed_inputs["prompt_token_ids"]
assert prompt_token_ids.count(config.boi_token_index) == num_imgs assert prompt_token_ids.count(config.boi_token_index) == num_imgs
assert prompt_token_ids.count(config.eoi_token_index) == num_imgs assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
aspect_ratios = mm_kwargs["aspect_ratios"] aspect_ratios = mm_data["aspect_ratios"]
num_x_separators = num_y_separators = 0 num_x_separators = num_y_separators = 0
for tiles_y, tiles_x in aspect_ratios: for tiles_y, tiles_x in aspect_ratios:
if tiles_x * tiles_y > 1: if tiles_x * tiles_y > 1:
...@@ -80,6 +80,6 @@ def test_processor_override( ...@@ -80,6 +80,6 @@ def test_processor_override(
num_patches_per_chunk = processor.info.get_patch_per_chunk( num_patches_per_chunk = processor.info.get_patch_per_chunk(
config.vision_config) config.vision_config)
assert prompt_token_ids.count(config.image_token_index) \ assert prompt_token_ids.count(config.image_token_index) \
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk == sum(mm_data["patches_per_image"]) * num_patches_per_chunk
assert mm_kwargs["pixel_values"].shape[0] \ assert len(mm_data["pixel_values"]) \
== mm_kwargs["patches_per_image"].sum() == sum(mm_data["patches_per_image"])
...@@ -49,18 +49,18 @@ def test_profiling( ...@@ -49,18 +49,18 @@ def test_profiling(
encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids) encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
] * max_num_seqs ] * max_num_seqs
mm_kwargs = processor.apply( mm_data = processor.apply(
prompt=dummy_mm_data.prompt, prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data, mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(), hf_processor_mm_kwargs=dict(),
)["mm_kwargs"] )["mm_kwargs"].get_data()
# Get the actual number of encoder tokens for each sample. # Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last # Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the # group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only. # block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details. # See MllamaMultiModalProcessor for more details.
num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")] num_tiles = [[t] for t in mm_data.pop("num_tiles")]
num_tokens_per_tile = calc_token_per_chunk(image_size) num_tokens_per_tile = calc_token_per_chunk(image_size)
actual_encoder_seq_lens = [ actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
......
...@@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int): ...@@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int):
hf_config = ctx.get_hf_config(Llama4Config) hf_config = ctx.get_hf_config(Llama4Config)
mm_kwargs = processor.apply( mm_data = processor.apply(
prompt=dummy_mm_data.prompt, prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data, mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(), hf_processor_mm_kwargs=dict(),
)["mm_kwargs"] )["mm_kwargs"].get_data()
image_size = hf_config.vision_config.image_size image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size patch_size = hf_config.vision_config.patch_size
downsample_ratio = int( downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))) round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
chunks_per_image = prod(mm_kwargs["patches_per_image"]) chunks_per_image = prod(mm_data["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][ num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][
0][1] # x-y seperator tokens 1] # x-y seperator tokens
total_tokens = total_num_patches.item() + num_tiles.item( total_tokens = total_num_patches.item() + num_tiles.item(
) + 3 # image start, image, image end ) + 3 # image start, image, image end
......
...@@ -70,7 +70,8 @@ def _run_check( ...@@ -70,7 +70,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<image>") image_token_id = tokenizer.convert_tokens_to_ids("<image>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape) print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape)
assert img_tok_count == 256 * total_expected_num_patches assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches
......
...@@ -48,7 +48,8 @@ def test_processor_override( ...@@ -48,7 +48,8 @@ def test_processor_override(
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values"].shape
assert img_tok_count == expected_toks_per_img * num_imgs assert img_tok_count == expected_toks_per_img * num_imgs
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
......
...@@ -128,7 +128,7 @@ def create_batched_mm_kwargs( ...@@ -128,7 +128,7 @@ def create_batched_mm_kwargs(
)["mm_kwargs"] )["mm_kwargs"]
items = [ items = [
item for modality in supported_mm_limits item for modality in supported_mm_limits
for item in mm_kwargs.get_items(modality) for item in mm_kwargs[modality]
] ]
return group_mm_kwargs_by_modality(items) return group_mm_kwargs_by_modality(items)
......
...@@ -4,8 +4,8 @@ import pytest ...@@ -4,8 +4,8 @@ import pytest
import torch import torch
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalKwargsItem, MultiModalKwargsItems,
MultiModalSharedField) MultiModalSharedField)
...@@ -24,8 +24,8 @@ def _dummy_item(modality: str, size_by_key: dict[str, int]): ...@@ -24,8 +24,8 @@ def _dummy_item(modality: str, size_by_key: dict[str, int]):
]) ])
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs([ return MultiModalKwargsItems.from_seq([
_dummy_item(modality, size_by_key) _dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items() for modality, size_by_key in size_by_key_modality.items()
]) ])
...@@ -37,7 +37,8 @@ def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): ...@@ -37,7 +37,8 @@ def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
[ [
(_dummy_item("a", {"a1": 100}), 100), (_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210), (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501
], ],
) )
# yapf: enable # yapf: enable
......
...@@ -11,7 +11,8 @@ import torch ...@@ -11,7 +11,8 @@ import torch
from vllm.multimodal.inputs import (MultiModalBatchedField, from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalFlatField, MultiModalFieldElem, MultiModalFlatField,
MultiModalKwargs, MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField, NestedTensors) MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
...@@ -96,7 +97,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): ...@@ -96,7 +97,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
class MyRequest(msgspec.Struct): class MyRequest(msgspec.Struct):
mm: Optional[list[MultiModalKwargs]] mm: Optional[list[MultiModalKwargsItems]]
def test_multimodal_kwargs(): def test_multimodal_kwargs():
...@@ -119,7 +120,7 @@ def test_multimodal_kwargs(): ...@@ -119,7 +120,7 @@ def test_multimodal_kwargs():
audio = MultiModalKwargsItem.from_elems([e1]) audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2]) video = MultiModalKwargsItem.from_elems([e2])
image = MultiModalKwargsItem.from_elems([e3, e4]) image = MultiModalKwargsItem.from_elems([e3, e4])
mm = MultiModalKwargs([audio, video, image]) mm = MultiModalKwargsItems.from_seq([audio, video, image])
# pack mm kwargs into a mock request so that it can be decoded properly # pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm]) req = MyRequest([mm])
...@@ -133,19 +134,22 @@ def test_multimodal_kwargs(): ...@@ -133,19 +134,22 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14255, +-20 for minor changes # expected total encoding length, should be 14306, +-20 for minor changes
assert 14250 <= total_len <= 14300 assert 14275 <= total_len <= 14325
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems)
# check all modalities were recovered and do some basic sanity checks # check all modalities were recovered and do some basic sanity checks
assert len(decoded.modalities) == 3 assert len(decoded) == 3
images = decoded.get_items("image") images = decoded["image"]
assert len(images) == 1 assert len(images) == 1
assert len(images[0].items()) == 2 assert len(images[0].items()) == 2
assert list(images[0].keys()) == ["i0", "i1"] assert list(images[0].keys()) == ["i0", "i1"]
# check the tensor contents and layout in the main dict # check the tensor contents and layout in the main dict
assert all(nested_equal(mm[k], decoded[k]) for k in mm) mm_data = mm.get_data()
decoded_data = decoded.get_data()
assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data)
def nested_equal(a: NestedTensors, b: NestedTensors): def nested_equal(a: NestedTensors, b: NestedTensors):
......
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
from array import array from array import array
from typing import Any, Type from typing import Any, Type
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any: def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types. """Custom msgspec enc hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
""" """
...@@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any: ...@@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any:
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.") f"Given array has a type code of {obj.typecode}.")
return obj.tobytes() return obj.tobytes()
if isinstance(obj, MultiModalKwargs):
return dict(obj)
def decode_hook(type: Type, obj: Any) -> Any: def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types. """Custom msgspec dec hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
""" """
...@@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any: ...@@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj) deserialized.frombytes(obj)
return deserialized return deserialized
if type is MultiModalKwargs:
return MultiModalKwargs(obj)
...@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs) MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
...@@ -470,7 +470,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): ...@@ -470,7 +470,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
......
...@@ -18,7 +18,7 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( ...@@ -18,7 +18,7 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -242,7 +242,7 @@ class AyaVisionMultiModalProcessor( ...@@ -242,7 +242,7 @@ class AyaVisionMultiModalProcessor(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token image_token = hf_processor.image_token
......
...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs) MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
...@@ -492,7 +492,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): ...@@ -492,7 +492,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
......
...@@ -31,7 +31,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -31,7 +31,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs) MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
...@@ -151,7 +151,7 @@ class ChameleonMultiModalProcessor( ...@@ -151,7 +151,7 @@ class ChameleonMultiModalProcessor(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -241,7 +241,7 @@ class Cohere2VisionMultiModalProcessor( ...@@ -241,7 +241,7 @@ class Cohere2VisionMultiModalProcessor(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token image_token = hf_processor.image_token
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment