Unverified Commit c6e7404c authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Multimodal] Simplify MM input definitions (#33331)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 17b17c06
...@@ -23,18 +23,16 @@ from vllm.multimodal.inputs import ( ...@@ -23,18 +23,16 @@ from vllm.multimodal.inputs import (
) )
def _dummy_elem(modality: str, key: str, size: int): def _dummy_elem(size: int):
return MultiModalFieldElem( return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size,), dtype=torch.int8), data=torch.empty((size,), dtype=torch.int8),
field=MultiModalSharedField(batch_size=1), field=MultiModalSharedField(batch_size=1),
) )
def _dummy_item(modality: str, size_by_key: dict[str, int]): def _dummy_item(size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems( return MultiModalKwargsItem(
[_dummy_elem(modality, key, size) for key, size in size_by_key.items()] {key: _dummy_elem(size) for key, size in size_by_key.items()}
) )
...@@ -61,7 +59,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase): ...@@ -61,7 +59,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
def test_minimal_put_get_cycle(self): def test_minimal_put_get_cycle(self):
"""Test basic put and get operations.""" """Test basic put and get operations."""
key = "test_key" key = "test_key"
value = _dummy_item("text", {"field1": 10, "field2": 20}) value = _dummy_item({"field1": 10, "field2": 20})
# Put operation # Put operation
address, monotonic_id = self.storage.put(key, value) address, monotonic_id = self.storage.put(key, value)
......
...@@ -119,7 +119,11 @@ def create_batched_mm_kwargs( ...@@ -119,7 +119,11 @@ def create_batched_mm_kwargs(
)["mm_kwargs"].require_data() )["mm_kwargs"].require_data()
return group_mm_kwargs_by_modality( return group_mm_kwargs_by_modality(
[item for modality in supported_mm_limits for item in mm_kwargs[modality]] [
(modality, item)
for modality in supported_mm_limits
for item in mm_kwargs[modality]
]
) )
......
...@@ -36,8 +36,6 @@ pytestmark = pytest.mark.cpu_test ...@@ -36,8 +36,6 @@ pytestmark = pytest.mark.cpu_test
def _dummy_elem( def _dummy_elem(
modality: str,
key: str,
size: int, size: int,
*, *,
rng: np.random.RandomState | None = None, rng: np.random.RandomState | None = None,
...@@ -48,21 +46,18 @@ def _dummy_elem( ...@@ -48,21 +46,18 @@ def _dummy_elem(
data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8)) data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8))
return MultiModalFieldElem( return MultiModalFieldElem(
modality=modality,
key=key,
data=data, data=data,
field=MultiModalSharedField(batch_size=1), field=MultiModalSharedField(batch_size=1),
) )
def _dummy_item( def _dummy_item(
modality: str,
size_by_key: dict[str, int], size_by_key: dict[str, int],
*, *,
rng: np.random.RandomState | None = None, rng: np.random.RandomState | None = None,
): ):
return MultiModalKwargsItem.from_elems( return MultiModalKwargsItem(
[_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()] {key: _dummy_elem(size, rng=rng) for key, size in size_by_key.items()}
) )
...@@ -71,19 +66,19 @@ def _dummy_items( ...@@ -71,19 +66,19 @@ def _dummy_items(
*, *,
rng: np.random.RandomState | None = None, rng: np.random.RandomState | None = None,
): ):
return MultiModalKwargsItems.from_seq( return MultiModalKwargsItems(
[ {
_dummy_item(modality, size_by_key, rng=rng) modality: [_dummy_item(size_by_key, rng=rng)]
for modality, size_by_key in size_by_key_modality.items() for modality, size_by_key in size_by_key_modality.items()
] }
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("item", "expected_size"), ("item", "expected_size"),
[ [
(_dummy_item("a", {"a1": 100}), 100), (_dummy_item({"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210), (_dummy_item({"a1": 100, "a2": 110}), 210),
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
], ],
) )
...@@ -143,7 +138,7 @@ def _compare_caches( ...@@ -143,7 +138,7 @@ def _compare_caches(
rng = np.random.RandomState(seed) rng = np.random.RandomState(seed)
all_items = [ all_items = [
_dummy_item("item", {"key": item_size_gb}, rng=rng) _dummy_item({"key": item_size_gb}, rng=rng)
for _ in range(int(item_capacity / hit_rate)) for _ in range(int(item_capacity / hit_rate))
] ]
all_hashes = [ all_hashes = [
...@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru( ...@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru(
"image_C", "image_C",
] ]
request1_items = { request1_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size) h: MultiModalKwargsItem.dummy(nbytes=2 * base_item_size)
for h in request1_hashes for h in request1_hashes
} }
request2_hashes = ["image_D", "image_E", "image_A", "image_C"] request2_hashes = ["image_D", "image_E", "image_A", "image_C"]
request2_items = { request2_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size) h: MultiModalKwargsItem.dummy(nbytes=1 * base_item_size)
for h in request2_hashes for h in request2_hashes
} }
...@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm( ...@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm(
): ):
request1_hashes = ["image_A", "image_B", "image_C"] request1_hashes = ["image_A", "image_B", "image_C"]
request1_items = { request1_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size) h: MultiModalKwargsItem.dummy(5 * base_item_size) for h in request1_hashes
for h in request1_hashes
} }
request1_items_p0_result = [] request1_items_p0_result = []
request2_hashes = ["image_G", "image_A"] request2_hashes = ["image_G", "image_A"]
request2_items = { request2_items = {
h: MultiModalKwargsItem.dummy( h: MultiModalKwargsItem.dummy(
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size (5 if h in request1_hashes else 2) * base_item_size
) )
for h in request2_hashes for h in request2_hashes
} }
...@@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm( ...@@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm(
request3_hashes = ["image_G", "image_H", "image_I", "image_B"] request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
request3_items = { request3_items = {
h: MultiModalKwargsItem.dummy( h: MultiModalKwargsItem.dummy(
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size (5 if h in request1_hashes else 2) * base_item_size
) )
for h in request3_hashes for h in request3_hashes
} }
...@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras(): ...@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras():
lora_a_identifier = f"12345:{base_mm_hash}" lora_a_identifier = f"12345:{base_mm_hash}"
lora_b_identifier = f"67890:{base_mm_hash}" lora_b_identifier = f"67890:{base_mm_hash}"
item_data = MultiModalKwargsItem.dummy("test_image", nbytes=1024) item_data = MultiModalKwargsItem.dummy(1024)
feature_lora_a = MultiModalFeatureSpec( feature_lora_a = MultiModalFeatureSpec(
data=item_data, data=item_data,
......
...@@ -77,7 +77,7 @@ def make_request( ...@@ -77,7 +77,7 @@ def make_request(
for j, position in enumerate(mm_positions): for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec( mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"), data=MultiModalKwargsItem.dummy(),
mm_position=position, mm_position=position,
identifier=identifier, identifier=identifier,
modality="image", modality="image",
......
...@@ -68,7 +68,7 @@ def make_request( ...@@ -68,7 +68,7 @@ def make_request(
for j, position in enumerate(mm_positions): for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec( mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"), data=MultiModalKwargsItem.dummy(),
mm_position=position, mm_position=position,
identifier=identifier, identifier=identifier,
modality="image", modality="image",
......
...@@ -56,7 +56,7 @@ def _create_random_request( ...@@ -56,7 +56,7 @@ def _create_random_request(
for j, position in enumerate(mm_positions): for j, position in enumerate(mm_positions):
identifier = f"{request_id}_hash_{j}" identifier = f"{request_id}_hash_{j}"
mm_feature = MultiModalFeatureSpec( mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"), data=MultiModalKwargsItem.dummy(),
mm_position=position, mm_position=position,
identifier=identifier, identifier=identifier,
modality="image", modality="image",
......
...@@ -1707,7 +1707,7 @@ def create_requests_with_priority( ...@@ -1707,7 +1707,7 @@ def create_requests_with_priority(
# Unique dummy hash for each mm item # Unique dummy hash for each mm item
identifier = f"hash{i}_{j}" identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec( mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"), data=MultiModalKwargsItem.dummy(),
mm_position=position, mm_position=position,
identifier=identifier, identifier=identifier,
modality="image", modality="image",
......
...@@ -236,7 +236,7 @@ def create_requests( ...@@ -236,7 +236,7 @@ def create_requests(
# Unique dummy hash for each mm item # Unique dummy hash for each mm item
identifier = f"hash{i}_{j}" identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec( mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"), data=MultiModalKwargsItem.dummy(),
mm_position=position, mm_position=position,
identifier=identifier, identifier=identifier,
modality="image", modality="image",
......
...@@ -131,7 +131,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat ...@@ -131,7 +131,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
# Step 1: Create initial request state with one multimodal feature # Step 1: Create initial request state with one multimodal feature
mm_feature_1 = MultiModalFeatureSpec( mm_feature_1 = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"), data=MultiModalKwargsItem.dummy(),
modality="audio", modality="audio",
identifier="audio_1", identifier="audio_1",
mm_position=PlaceholderRange(offset=2, length=10), mm_position=PlaceholderRange(offset=2, length=10),
...@@ -158,7 +158,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat ...@@ -158,7 +158,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
# The scheduler has already set prompt_token_ids to the full sequence # The scheduler has already set prompt_token_ids to the full sequence
# (original prompt + intermediate outputs + new prompt with new multimodal feature) # (original prompt + intermediate outputs + new prompt with new multimodal feature)
mm_feature_2 = MultiModalFeatureSpec( mm_feature_2 = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"), data=MultiModalKwargsItem.dummy(),
modality="audio", modality="audio",
identifier="audio_2", identifier="audio_2",
mm_position=PlaceholderRange(offset=15, length=5), mm_position=PlaceholderRange(offset=15, length=5),
......
...@@ -174,7 +174,7 @@ class TestStreamingScheduler(unittest.TestCase): ...@@ -174,7 +174,7 @@ class TestStreamingScheduler(unittest.TestCase):
scheduler = create_scheduler() scheduler = create_scheduler()
mm_feature = MultiModalFeatureSpec( mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"), data=MultiModalKwargsItem.dummy(),
modality="audio", modality="audio",
identifier="", identifier="",
mm_position=PlaceholderRange(offset=1, length=1), mm_position=PlaceholderRange(offset=1, length=1),
...@@ -187,7 +187,7 @@ class TestStreamingScheduler(unittest.TestCase): ...@@ -187,7 +187,7 @@ class TestStreamingScheduler(unittest.TestCase):
session.num_computed_tokens = len(session.prompt_token_ids) session.num_computed_tokens = len(session.prompt_token_ids)
mm_feature = MultiModalFeatureSpec( mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"), data=MultiModalKwargsItem.dummy(),
modality="audio", modality="audio",
identifier="", identifier="",
mm_position=PlaceholderRange(offset=2, length=1), mm_position=PlaceholderRange(offset=2, length=1),
......
...@@ -104,14 +104,10 @@ class MyRequest(msgspec.Struct): ...@@ -104,14 +104,10 @@ class MyRequest(msgspec.Struct):
def test_multimodal_kwargs(): def test_multimodal_kwargs():
e1 = MultiModalFieldElem( e1 = MultiModalFieldElem(
"audio",
"a0",
torch.zeros(1000, dtype=torch.bfloat16), torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField(), MultiModalBatchedField(),
) )
e2 = MultiModalFieldElem( e2 = MultiModalFieldElem(
"video",
"v0",
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)], [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
MultiModalFlatField( MultiModalFlatField(
slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
...@@ -119,21 +115,20 @@ def test_multimodal_kwargs(): ...@@ -119,21 +115,20 @@ def test_multimodal_kwargs():
), ),
) )
e3 = MultiModalFieldElem( e3 = MultiModalFieldElem(
"image",
"i0",
torch.zeros(1000, dtype=torch.int32), torch.zeros(1000, dtype=torch.int32),
MultiModalSharedField(batch_size=4), MultiModalSharedField(batch_size=4),
) )
e4 = MultiModalFieldElem( e4 = MultiModalFieldElem(
"image",
"i1",
torch.zeros(1000, dtype=torch.int32), torch.zeros(1000, dtype=torch.int32),
MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2), MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
) )
audio = MultiModalKwargsItem.from_elems([e1]) mm = MultiModalKwargsItems(
video = MultiModalKwargsItem.from_elems([e2]) {
image = MultiModalKwargsItem.from_elems([e3, e4]) "audio": [MultiModalKwargsItem({"a0": e1})],
mm = MultiModalKwargsItems.from_seq([audio, video, image]) "video": [MultiModalKwargsItem({"v0": e2})],
"image": [MultiModalKwargsItem({"i0": e3, "i1": e4})],
}
)
# pack mm kwargs into a mock request so that it can be decoded properly # pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm]) req = MyRequest([mm])
...@@ -147,8 +142,8 @@ def test_multimodal_kwargs(): ...@@ -147,8 +142,8 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14395, +-20 for minor changes # expected total encoding length, should be 14319, +-20 for minor changes
assert 14375 <= total_len <= 14425 assert 14300 <= total_len <= 14340
decoded = decoder.decode(encoded).mm[0] decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems) assert isinstance(decoded, MultiModalKwargsItems)
......
...@@ -463,8 +463,8 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): ...@@ -463,8 +463,8 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
ring_buffer=ring_buffer, ring_buffer=ring_buffer,
serde_class=MsgpackSerde, serde_class=MsgpackSerde,
) )
# cache (prompt_updates, modality) for P0 only # cache prompt_updates for P0 only
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {} self._p0_cache: dict[str, Sequence[ResolvedPromptUpdate]] = {}
self._hits = 0 self._hits = 0
self._total = 0 self._total = 0
...@@ -495,23 +495,22 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): ...@@ -495,23 +495,22 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
self._total += 1 self._total += 1
address, monotonic_id = self._shm_cache.get_cached(mm_hash) address, monotonic_id = self._shm_cache.get_cached(mm_hash)
prompt_updates, modality = self._p0_cache[mm_hash] prompt_updates = self._p0_cache[mm_hash]
return self.address_as_item(address, monotonic_id, modality), prompt_updates return self.address_as_item(address, monotonic_id), prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}" assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
item, prompt_updates = mm_item
self._total += 1 self._total += 1
try: try:
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0]) address, monotonic_id = self._shm_cache.put(mm_hash, item)
# Try to remove dangling items if p0 cache is too large. # Try to remove dangling items if p0 cache is too large.
if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index): if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
self.remove_dangling_items() self.remove_dangling_items()
self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
address_item = self.address_as_item( self._p0_cache[mm_hash] = prompt_updates
address, monotonic_id, mm_item[0].modality return self.address_as_item(address, monotonic_id), prompt_updates
)
return address_item, mm_item[1]
except (ValueError, MemoryError) as e: except (ValueError, MemoryError) as e:
# put may fail if the object is too large or # put may fail if the object is too large or
# the cache is full. # the cache is full.
...@@ -550,22 +549,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): ...@@ -550,22 +549,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
del self._p0_cache[mm_hash] del self._p0_cache[mm_hash]
def address_as_item( def address_as_item(
self, address: int, monotonic_id: int, modality: str self,
address: int,
monotonic_id: int,
) -> MultiModalKwargsItem: ) -> MultiModalKwargsItem:
addr_elem = MultiModalFieldElem( addr_elem = MultiModalFieldElem(
modality=modality,
key="address",
data=address, data=address,
field=MultiModalBatchedField(), field=MultiModalBatchedField(),
) )
id_elem = MultiModalFieldElem( id_elem = MultiModalFieldElem(
modality=modality,
key="monotonic_id",
data=monotonic_id, data=monotonic_id,
field=MultiModalBatchedField(), field=MultiModalBatchedField(),
) )
mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
return mm_item return MultiModalKwargsItem({"address": addr_elem, "monotonic_id": id_elem})
class BaseMultiModalReceiverCache( class BaseMultiModalReceiverCache(
......
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
from PIL.Image import Image from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar from typing_extensions import NotRequired, TypeVar
from vllm.utils.collection_utils import full_groupby, is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
...@@ -336,25 +336,33 @@ class MultiModalFeatureSpec: ...@@ -336,25 +336,33 @@ class MultiModalFeatureSpec:
""" """
Represents a single multimodal input with its processed data and metadata. Represents a single multimodal input with its processed data and metadata.
Used by the V1 engine to track multimodal data through processing and Used to track multimodal data through processing and caching.
caching. A request containing multiple multimodal items will have one A request containing multiple multimodal items will have one
MultiModalFeatureSpec per item. `MultiModalFeatureSpec` per item.
""" """
data: Optional["MultiModalKwargsItem"] data: Optional["MultiModalKwargsItem"]
"""Multimodal data for this feature""" """
Represents multimodal data for this feature.
Can be `None` if the item is cached, to skip IPC between API server
and engine core processes.
"""
modality: str modality: str
"""Based on the input, e.g., "image", "audio", "video".""" """The input modality, e.g., `"image"`, `"audio"`, `"video"`."""
identifier: str identifier: str
"""mm_hash or uuid for caching encoder outputs.""" """The hash for caching encoder outputs (with LoRA prefix if applicable)."""
mm_position: PlaceholderRange mm_position: PlaceholderRange
"""e.g., PlaceholderRange(offset=2, length=336)""" """
The location of the `modality` tokens corresponding to this item
in the prompt, e.g., `PlaceholderRange(offset=2, length=336)`.
"""
mm_hash: str | None = None mm_hash: str | None = None
"""Base mm_hash for processor cache (without LoRA prefix).""" """The hash for caching processor outputs (without LoRA prefix)."""
@staticmethod @staticmethod
def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]): def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
...@@ -373,23 +381,10 @@ class MultiModalFeatureSpec: ...@@ -373,23 +381,10 @@ class MultiModalFeatureSpec:
@dataclass @dataclass
class MultiModalFieldElem: class MultiModalFieldElem:
""" """
Represents a keyword argument inside a Represents a processed keyword argument to pass to a model for a
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]. [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
""" """
modality: str
"""
The modality of the corresponding multi-modal item.
Each multi-modal item can consist of multiple keyword arguments.
"""
key: str
"""
The key of this field in
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
i.e. the name of the keyword argument to be passed to the model.
"""
data: NestedTensors data: NestedTensors
""" """
The tensor data of this field in The tensor data of this field in
...@@ -417,11 +412,7 @@ class MultiModalFieldElem: ...@@ -417,11 +412,7 @@ class MultiModalFieldElem:
else: else:
data_equal = nested_tensors_equal(self.data, other.data) data_equal = nested_tensors_equal(self.data, other.data)
return ( return data_equal and type(self.field) is type(other.field) # noqa: E721
(self.modality, self.key) == (other.modality, other.key)
and data_equal
and type(self.field) is type(other.field)
) # noqa: E721
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
...@@ -438,13 +429,8 @@ class BaseMultiModalField(ABC): ...@@ -438,13 +429,8 @@ class BaseMultiModalField(ABC):
when `MultiModalKwargsItems.get_data()` is called to batch the data. when `MultiModalKwargsItems.get_data()` is called to batch the data.
""" """
def _field_factory(self, *, modality: str, key: str): def _field_factory(self):
f = partial( f = partial(MultiModalFieldElem, field=self)
MultiModalFieldElem,
modality=modality,
key=key,
field=self,
)
# Allow passing data as positional argument # Allow passing data as positional argument
def factory(data: NestedTensors) -> MultiModalFieldElem: def factory(data: NestedTensors) -> MultiModalFieldElem:
...@@ -519,7 +505,7 @@ class MultiModalBatchedField(BaseMultiModalField): ...@@ -519,7 +505,7 @@ class MultiModalBatchedField(BaseMultiModalField):
key: str, key: str,
data: NestedTensors, data: NestedTensors,
) -> Sequence[MultiModalFieldElem]: ) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key) field_factory = self._field_factory()
return [field_factory(item) for item in data] return [field_factory(item) for item in data]
def _reduce_data( def _reduce_data(
...@@ -565,7 +551,7 @@ class MultiModalFlatField(BaseMultiModalField): ...@@ -565,7 +551,7 @@ class MultiModalFlatField(BaseMultiModalField):
key: str, key: str,
data: NestedTensors, data: NestedTensors,
) -> Sequence[MultiModalFieldElem]: ) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key) field_factory = self._field_factory()
if not is_list_of(self.slices, slice, check="all"): if not is_list_of(self.slices, slice, check="all"):
assert isinstance(data, torch.Tensor), ( assert isinstance(data, torch.Tensor), (
"torch.Tensor is required for multiple slices" "torch.Tensor is required for multiple slices"
...@@ -664,7 +650,7 @@ class MultiModalSharedField(BaseMultiModalField): ...@@ -664,7 +650,7 @@ class MultiModalSharedField(BaseMultiModalField):
key: str, key: str,
data: NestedTensors, data: NestedTensors,
) -> Sequence[MultiModalFieldElem]: ) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key) field_factory = self._field_factory()
return [field_factory(data)] * self.batch_size return [field_factory(data)] * self.batch_size
def _reduce_data( def _reduce_data(
...@@ -899,37 +885,19 @@ class MultiModalFieldConfig: ...@@ -899,37 +885,19 @@ class MultiModalFieldConfig:
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
""" """
A collection of A dictionary of processed keyword arguments to pass to the model,
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem] corresponding to a single item in
corresponding to a data item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
""" """
@staticmethod @staticmethod
def dummy(modality: str, nbytes: int = 1): def dummy(nbytes: int = 1):
"""Convenience class for testing.""" """Convenience class for testing."""
mm_elem = MultiModalFieldElem( mm_elem = MultiModalFieldElem(
modality=modality,
key="dummy",
data=torch.empty(nbytes, dtype=torch.uint8), data=torch.empty(nbytes, dtype=torch.uint8),
field=MultiModalSharedField(batch_size=1), field=MultiModalSharedField(batch_size=1),
) )
return MultiModalKwargsItem.from_elems([mm_elem]) return MultiModalKwargsItem({"dummy": mm_elem})
@staticmethod
def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.key: elem for elem in elems})
def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
super().__init__(data)
modalities = {elem.modality for elem in self.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}"
self._modality = next(iter(modalities))
@property
def modality(self) -> str:
return self._modality
def get_data(self) -> dict[str, NestedTensors]: def get_data(self) -> dict[str, NestedTensors]:
return {key: elem.data for key, elem in self.items()} return {key: elem.data for key, elem in self.items()}
...@@ -945,9 +913,38 @@ _I = TypeVar( ...@@ -945,9 +913,38 @@ _I = TypeVar(
class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
""" """
A dictionary of A dictionary of processed multi-modal inputs by modality.
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
by modality. For example, given a processor that processes
images into `pixel_values` and `image_grid_thw`,
and audios into `input_audio_features`,
a prompt with 2 images and 1 audio will be processed
into a `MultiModalKwargsItems` with the following structure:
```python
MultiModalKwargsItems(
{
"image": [
# For the first image
MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
# For the second imgae
MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
],
"audio": [
# For the first audio
MultiModalKwargsItem({"input_audio_features": ...}),
],
}
)
```
Unlike HF processing which returns all items
in a single dictionary with batched keyword arguments,
we split up the items because some of them may already be cached.
Also, items from multiple requests may be batched together to improve throughput,
using the logic defined by the
[`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
for each keyword argument.
""" """
@staticmethod @staticmethod
...@@ -967,7 +964,7 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): ...@@ -967,7 +964,7 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
elems_by_key[key] = elems elems_by_key[key] = elems
keys_by_modality[config.modality].add(key) keys_by_modality[config.modality].add(key)
items = list[MultiModalKwargsItem]() items_by_modality = dict[str, list[MultiModalKwargsItem]]()
for modality, keys in keys_by_modality.items(): for modality, keys in keys_by_modality.items():
elems_in_modality = {k: elems_by_key[k] for k in keys} elems_in_modality = {k: elems_by_key[k] for k in keys}
batch_sizes = {k: len(v) for k, v in elems_in_modality.items()} batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
...@@ -979,15 +976,11 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): ...@@ -979,15 +976,11 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
) )
batch_size = next(iter(batch_sizes.values())) batch_size = next(iter(batch_sizes.values()))
for item_idx in range(batch_size): items_by_modality[modality] = [
elems = [v[item_idx] for v in elems_in_modality.values()] MultiModalKwargsItem({k: v[i] for k, v in elems_in_modality.items()})
items.append(MultiModalKwargsItem.from_elems(elems)) for i in range(batch_size)
]
return MultiModalKwargsItems.from_seq(items)
@staticmethod
def from_seq(items: Sequence[MultiModalKwargsItem]):
items_by_modality = full_groupby(items, key=lambda x: x.modality)
return MultiModalKwargsItems(items_by_modality) return MultiModalKwargsItems(items_by_modality)
def __getitem__(self, modality: str) -> Sequence[_I]: def __getitem__(self, modality: str) -> Sequence[_I]:
......
...@@ -467,7 +467,7 @@ def argsort_mm_positions( ...@@ -467,7 +467,7 @@ def argsort_mm_positions(
def group_mm_kwargs_by_modality( def group_mm_kwargs_by_modality(
mm_kwargs: list[MultiModalKwargsItem], mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
*, *,
device: torch.types.Device = None, device: torch.types.Device = None,
pin_memory: bool = False, pin_memory: bool = False,
...@@ -485,9 +485,9 @@ def group_mm_kwargs_by_modality( ...@@ -485,9 +485,9 @@ def group_mm_kwargs_by_modality(
""" """
from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalKwargsItems
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): for modality, group in groupby(mm_kwargs, key=lambda x: x[0]):
items_lst = list(items) items_lst = [item for _, item in group]
mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst) mm_kwargs_items = MultiModalKwargsItems({modality: items_lst})
mm_kwargs_data = mm_kwargs_items.get_data( mm_kwargs_data = mm_kwargs_items.get_data(
device=device, device=device,
pin_memory=pin_memory, pin_memory=pin_memory,
......
...@@ -242,13 +242,11 @@ class MsgpackEncoder: ...@@ -242,13 +242,11 @@ class MsgpackEncoder:
for modality, itemlist in items.items() for modality, itemlist in items.items()
} }
def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]: def _encode_mm_item(self, item: MultiModalKwargsItem) -> dict[str, Any]:
return [self._encode_mm_field_elem(elem) for elem in item.values()] return {key: self._encode_mm_field_elem(elem) for key, elem in item.items()}
def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]: def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
return { return {
"modality": elem.modality,
"key": elem.key,
"data": ( "data": (
None if elem.data is None else self._encode_nested_tensors(elem.data) None if elem.data is None else self._encode_nested_tensors(elem.data)
), ),
...@@ -383,9 +381,9 @@ class MsgpackDecoder: ...@@ -383,9 +381,9 @@ class MsgpackDecoder:
} }
) )
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: def _decode_mm_item(self, obj: dict[str, Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems( return MultiModalKwargsItem(
[self._decode_mm_field_elem(v) for v in obj] {key: self._decode_mm_field_elem(elem) for key, elem in obj.items()}
) )
def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem: def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
......
...@@ -43,9 +43,9 @@ class EncoderRunner: ...@@ -43,9 +43,9 @@ class EncoderRunner:
def prepare_mm_inputs( def prepare_mm_inputs(
self, self,
scheduled_encoder_inputs: dict[str, list[int]], scheduled_encoder_inputs: dict[str, list[int]],
) -> tuple[list[str], list[MultiModalKwargsItem]]: ) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
mm_hashes: list[str] = [] mm_hashes: list[str] = []
mm_kwargs: list[MultiModalKwargsItem] = [] mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id] mm_features = self.req_id_to_mm_features[req_id]
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
...@@ -53,7 +53,8 @@ class EncoderRunner: ...@@ -53,7 +53,8 @@ class EncoderRunner:
if mm_feature.data is None: if mm_feature.data is None:
continue continue
mm_hashes.append(mm_feature.identifier) mm_hashes.append(mm_feature.identifier)
mm_kwargs.append(mm_feature.data) mm_kwargs.append((mm_feature.modality, mm_feature.data))
return mm_hashes, mm_kwargs return mm_hashes, mm_kwargs
@torch.inference_mode() @torch.inference_mode()
...@@ -61,7 +62,7 @@ class EncoderRunner: ...@@ -61,7 +62,7 @@ class EncoderRunner:
self, self,
model: SupportsMultiModal, model: SupportsMultiModal,
mm_hashes: list[str], mm_hashes: list[str],
mm_kwargs: list[MultiModalKwargsItem], mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
if not mm_hashes: if not mm_hashes:
return [] return []
......
...@@ -1217,11 +1217,11 @@ class GPUModelRunner( ...@@ -1217,11 +1217,11 @@ class GPUModelRunner(
if not scheduler_output or not self.is_multimodal_raw_input_only_model: if not scheduler_output or not self.is_multimodal_raw_input_only_model:
return {} return {}
mm_kwargs = list[MultiModalKwargsItem]() mm_kwargs = list[tuple[str, MultiModalKwargsItem]]()
for req in scheduler_output.scheduled_new_reqs: for req in scheduler_output.scheduled_new_reqs:
for feature in req.mm_features: for feature in req.mm_features:
if feature.data is not None: if feature.data is not None:
mm_kwargs.append(feature.data) mm_kwargs.append((feature.modality, feature.data))
# Input all modalities at once # Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {} mm_kwargs_combined: BatchedTensorInputs = {}
...@@ -2219,7 +2219,7 @@ class GPUModelRunner( ...@@ -2219,7 +2219,7 @@ class GPUModelRunner(
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> tuple[ ) -> tuple[
list[str], list[str],
list[MultiModalKwargsItem], list[tuple[str, MultiModalKwargsItem]],
list[tuple[str, PlaceholderRange]], list[tuple[str, PlaceholderRange]],
]: ]:
"""Batch multimodal inputs from scheduled encoder inputs. """Batch multimodal inputs from scheduled encoder inputs.
...@@ -2239,7 +2239,7 @@ class GPUModelRunner( ...@@ -2239,7 +2239,7 @@ class GPUModelRunner(
return [], [], [] return [], [], []
mm_hashes = list[str]() mm_hashes = list[str]()
mm_kwargs = list[MultiModalKwargsItem]() mm_kwargs = list[tuple[str, MultiModalKwargsItem]]()
# Multimodal LoRA reference info to map each multimodal item # Multimodal LoRA reference info to map each multimodal item
# back to its request & position # back to its request & position
mm_lora_refs = list[tuple[str, PlaceholderRange]]() mm_lora_refs = list[tuple[str, PlaceholderRange]]()
...@@ -2252,7 +2252,7 @@ class GPUModelRunner( ...@@ -2252,7 +2252,7 @@ class GPUModelRunner(
continue continue
mm_hashes.append(mm_feature.identifier) mm_hashes.append(mm_feature.identifier)
mm_kwargs.append(mm_feature.data) mm_kwargs.append((mm_feature.modality, mm_feature.data))
mm_lora_refs.append((req_id, mm_feature.mm_position)) mm_lora_refs.append((req_id, mm_feature.mm_position))
return mm_hashes, mm_kwargs, mm_lora_refs return mm_hashes, mm_kwargs, mm_lora_refs
...@@ -4475,12 +4475,10 @@ class GPUModelRunner( ...@@ -4475,12 +4475,10 @@ class GPUModelRunner(
# but not read from the cache # but not read from the cache
assert dummy_mm_item is not None, "Item should not already be cached" assert dummy_mm_item is not None, "Item should not already be cached"
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next( return next(
mm_kwargs_group mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
dummy_mm_items, [(modality, dummy_mm_item)] * max_items_per_batch,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment