"vscode:/vscode.git/clone" did not exist on "2af12d7a34b815b0d2aafe64e6bbfa157a3f278c"
Unverified commit c6e7404c, authored by Cyrus Leung, committed by GitHub
Browse files

[Multimodal] Simplify MM input definitions (#33331)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 17b17c06
......@@ -23,18 +23,16 @@ from vllm.multimodal.inputs import (
)
def _dummy_elem(modality: str, key: str, size: int):
def _dummy_elem(size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size,), dtype=torch.int8),
field=MultiModalSharedField(batch_size=1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems(
[_dummy_elem(modality, key, size) for key, size in size_by_key.items()]
def _dummy_item(size_by_key: dict[str, int]):
return MultiModalKwargsItem(
{key: _dummy_elem(size) for key, size in size_by_key.items()}
)
......@@ -61,7 +59,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
def test_minimal_put_get_cycle(self):
"""Test basic put and get operations."""
key = "test_key"
value = _dummy_item("text", {"field1": 10, "field2": 20})
value = _dummy_item({"field1": 10, "field2": 20})
# Put operation
address, monotonic_id = self.storage.put(key, value)
......
......@@ -119,7 +119,11 @@ def create_batched_mm_kwargs(
)["mm_kwargs"].require_data()
return group_mm_kwargs_by_modality(
[item for modality in supported_mm_limits for item in mm_kwargs[modality]]
[
(modality, item)
for modality in supported_mm_limits
for item in mm_kwargs[modality]
]
)
......
......@@ -36,8 +36,6 @@ pytestmark = pytest.mark.cpu_test
def _dummy_elem(
modality: str,
key: str,
size: int,
*,
rng: np.random.RandomState | None = None,
......@@ -48,21 +46,18 @@ def _dummy_elem(
data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8))
return MultiModalFieldElem(
modality=modality,
key=key,
data=data,
field=MultiModalSharedField(batch_size=1),
)
def _dummy_item(
modality: str,
size_by_key: dict[str, int],
*,
rng: np.random.RandomState | None = None,
):
return MultiModalKwargsItem.from_elems(
[_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()]
return MultiModalKwargsItem(
{key: _dummy_elem(size, rng=rng) for key, size in size_by_key.items()}
)
......@@ -71,19 +66,19 @@ def _dummy_items(
*,
rng: np.random.RandomState | None = None,
):
return MultiModalKwargsItems.from_seq(
[
_dummy_item(modality, size_by_key, rng=rng)
return MultiModalKwargsItems(
{
modality: [_dummy_item(size_by_key, rng=rng)]
for modality, size_by_key in size_by_key_modality.items()
]
}
)
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_item({"a1": 100}), 100),
(_dummy_item({"a1": 100, "a2": 110}), 210),
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
],
)
......@@ -143,7 +138,7 @@ def _compare_caches(
rng = np.random.RandomState(seed)
all_items = [
_dummy_item("item", {"key": item_size_gb}, rng=rng)
_dummy_item({"key": item_size_gb}, rng=rng)
for _ in range(int(item_capacity / hit_rate))
]
all_hashes = [
......@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru(
"image_C",
]
request1_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size)
h: MultiModalKwargsItem.dummy(nbytes=2 * base_item_size)
for h in request1_hashes
}
request2_hashes = ["image_D", "image_E", "image_A", "image_C"]
request2_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size)
h: MultiModalKwargsItem.dummy(nbytes=1 * base_item_size)
for h in request2_hashes
}
......@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm(
):
request1_hashes = ["image_A", "image_B", "image_C"]
request1_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size)
for h in request1_hashes
h: MultiModalKwargsItem.dummy(5 * base_item_size) for h in request1_hashes
}
request1_items_p0_result = []
request2_hashes = ["image_G", "image_A"]
request2_items = {
h: MultiModalKwargsItem.dummy(
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
(5 if h in request1_hashes else 2) * base_item_size
)
for h in request2_hashes
}
......@@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm(
request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
request3_items = {
h: MultiModalKwargsItem.dummy(
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
(5 if h in request1_hashes else 2) * base_item_size
)
for h in request3_hashes
}
......@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras():
lora_a_identifier = f"12345:{base_mm_hash}"
lora_b_identifier = f"67890:{base_mm_hash}"
item_data = MultiModalKwargsItem.dummy("test_image", nbytes=1024)
item_data = MultiModalKwargsItem.dummy(1024)
feature_lora_a = MultiModalFeatureSpec(
data=item_data,
......
......@@ -77,7 +77,7 @@ def make_request(
for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
data=MultiModalKwargsItem.dummy(),
mm_position=position,
identifier=identifier,
modality="image",
......
......@@ -68,7 +68,7 @@ def make_request(
for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
data=MultiModalKwargsItem.dummy(),
mm_position=position,
identifier=identifier,
modality="image",
......
......@@ -56,7 +56,7 @@ def _create_random_request(
for j, position in enumerate(mm_positions):
identifier = f"{request_id}_hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
data=MultiModalKwargsItem.dummy(),
mm_position=position,
identifier=identifier,
modality="image",
......
......@@ -1707,7 +1707,7 @@ def create_requests_with_priority(
# Unique dummy hash for each mm item
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
data=MultiModalKwargsItem.dummy(),
mm_position=position,
identifier=identifier,
modality="image",
......
......@@ -236,7 +236,7 @@ def create_requests(
# Unique dummy hash for each mm item
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
data=MultiModalKwargsItem.dummy(),
mm_position=position,
identifier=identifier,
modality="image",
......
......@@ -131,7 +131,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
# Step 1: Create initial request state with one multimodal feature
mm_feature_1 = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"),
data=MultiModalKwargsItem.dummy(),
modality="audio",
identifier="audio_1",
mm_position=PlaceholderRange(offset=2, length=10),
......@@ -158,7 +158,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
# The scheduler has already set prompt_token_ids to the full sequence
# (original prompt + intermediate outputs + new prompt with new multimodal feature)
mm_feature_2 = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"),
data=MultiModalKwargsItem.dummy(),
modality="audio",
identifier="audio_2",
mm_position=PlaceholderRange(offset=15, length=5),
......
......@@ -174,7 +174,7 @@ class TestStreamingScheduler(unittest.TestCase):
scheduler = create_scheduler()
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"),
data=MultiModalKwargsItem.dummy(),
modality="audio",
identifier="",
mm_position=PlaceholderRange(offset=1, length=1),
......@@ -187,7 +187,7 @@ class TestStreamingScheduler(unittest.TestCase):
session.num_computed_tokens = len(session.prompt_token_ids)
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"),
data=MultiModalKwargsItem.dummy(),
modality="audio",
identifier="",
mm_position=PlaceholderRange(offset=2, length=1),
......
......@@ -104,14 +104,10 @@ class MyRequest(msgspec.Struct):
def test_multimodal_kwargs():
e1 = MultiModalFieldElem(
"audio",
"a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField(),
)
e2 = MultiModalFieldElem(
"video",
"v0",
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
MultiModalFlatField(
slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
......@@ -119,21 +115,20 @@ def test_multimodal_kwargs():
),
)
e3 = MultiModalFieldElem(
"image",
"i0",
torch.zeros(1000, dtype=torch.int32),
MultiModalSharedField(batch_size=4),
)
e4 = MultiModalFieldElem(
"image",
"i1",
torch.zeros(1000, dtype=torch.int32),
MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
)
audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2])
image = MultiModalKwargsItem.from_elems([e3, e4])
mm = MultiModalKwargsItems.from_seq([audio, video, image])
mm = MultiModalKwargsItems(
{
"audio": [MultiModalKwargsItem({"a0": e1})],
"video": [MultiModalKwargsItem({"v0": e2})],
"image": [MultiModalKwargsItem({"i0": e3, "i1": e4})],
}
)
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm])
......@@ -147,8 +142,8 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14395, +-20 for minor changes
assert 14375 <= total_len <= 14425
# expected total encoding length, should be 14319, +-20 for minor changes
assert 14300 <= total_len <= 14340
decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems)
......
......@@ -463,8 +463,8 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
ring_buffer=ring_buffer,
serde_class=MsgpackSerde,
)
# cache (prompt_updates, modality) for P0 only
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}
# cache prompt_updates for P0 only
self._p0_cache: dict[str, Sequence[ResolvedPromptUpdate]] = {}
self._hits = 0
self._total = 0
......@@ -495,23 +495,22 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
self._total += 1
address, monotonic_id = self._shm_cache.get_cached(mm_hash)
prompt_updates, modality = self._p0_cache[mm_hash]
return self.address_as_item(address, monotonic_id, modality), prompt_updates
prompt_updates = self._p0_cache[mm_hash]
return self.address_as_item(address, monotonic_id), prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
item, prompt_updates = mm_item
self._total += 1
try:
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
address, monotonic_id = self._shm_cache.put(mm_hash, item)
# Try to remove dangling items if p0 cache is too large.
if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
self.remove_dangling_items()
self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
address_item = self.address_as_item(
address, monotonic_id, mm_item[0].modality
)
return address_item, mm_item[1]
self._p0_cache[mm_hash] = prompt_updates
return self.address_as_item(address, monotonic_id), prompt_updates
except (ValueError, MemoryError) as e:
# put may fail if the object is too large or
# the cache is full.
......@@ -550,22 +549,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
del self._p0_cache[mm_hash]
def address_as_item(
self, address: int, monotonic_id: int, modality: str
self,
address: int,
monotonic_id: int,
) -> MultiModalKwargsItem:
addr_elem = MultiModalFieldElem(
modality=modality,
key="address",
data=address,
field=MultiModalBatchedField(),
)
id_elem = MultiModalFieldElem(
modality=modality,
key="monotonic_id",
data=monotonic_id,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
return mm_item
return MultiModalKwargsItem({"address": addr_elem, "monotonic_id": id_elem})
class BaseMultiModalReceiverCache(
......
......@@ -23,7 +23,7 @@ import numpy as np
from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar
from vllm.utils.collection_utils import full_groupby, is_list_of
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
from vllm.utils.jsontree import json_map_leaves
......@@ -336,25 +336,33 @@ class MultiModalFeatureSpec:
"""
Represents a single multimodal input with its processed data and metadata.
Used by the V1 engine to track multimodal data through processing and
caching. A request containing multiple multimodal items will have one
MultiModalFeatureSpec per item.
Used to track multimodal data through processing and caching.
A request containing multiple multimodal items will have one
`MultiModalFeatureSpec` per item.
"""
data: Optional["MultiModalKwargsItem"]
"""Multimodal data for this feature"""
"""
Represents multimodal data for this feature.
Can be `None` if the item is cached, to skip IPC between API server
and engine core processes.
"""
modality: str
"""Based on the input, e.g., "image", "audio", "video"."""
"""The input modality, e.g., `"image"`, `"audio"`, `"video"`."""
identifier: str
"""mm_hash or uuid for caching encoder outputs."""
"""The hash for caching encoder outputs (with LoRA prefix if applicable)."""
mm_position: PlaceholderRange
"""e.g., PlaceholderRange(offset=2, length=336)"""
"""
The location of the `modality` tokens corresponding to this item
in the prompt, e.g., `PlaceholderRange(offset=2, length=336)`.
"""
mm_hash: str | None = None
"""Base mm_hash for processor cache (without LoRA prefix)."""
"""The hash for caching processor outputs (without LoRA prefix)."""
@staticmethod
def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
......@@ -373,23 +381,10 @@ class MultiModalFeatureSpec:
@dataclass
class MultiModalFieldElem:
"""
Represents a keyword argument inside a
Represents a processed keyword argument to pass to a model for a
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
"""
modality: str
"""
The modality of the corresponding multi-modal item.
Each multi-modal item can consist of multiple keyword arguments.
"""
key: str
"""
The key of this field in
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
i.e. the name of the keyword argument to be passed to the model.
"""
data: NestedTensors
"""
The tensor data of this field in
......@@ -417,11 +412,7 @@ class MultiModalFieldElem:
else:
data_equal = nested_tensors_equal(self.data, other.data)
return (
(self.modality, self.key) == (other.modality, other.key)
and data_equal
and type(self.field) is type(other.field)
) # noqa: E721
return data_equal and type(self.field) is type(other.field) # noqa: E721
@dataclass(frozen=True, kw_only=True)
......@@ -438,13 +429,8 @@ class BaseMultiModalField(ABC):
when `MultiModalKwargsItems.get_data()` is called to batch the data.
"""
def _field_factory(self, *, modality: str, key: str):
f = partial(
MultiModalFieldElem,
modality=modality,
key=key,
field=self,
)
def _field_factory(self):
f = partial(MultiModalFieldElem, field=self)
# Allow passing data as positional argument
def factory(data: NestedTensors) -> MultiModalFieldElem:
......@@ -519,7 +505,7 @@ class MultiModalBatchedField(BaseMultiModalField):
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
field_factory = self._field_factory()
return [field_factory(item) for item in data]
def _reduce_data(
......@@ -565,7 +551,7 @@ class MultiModalFlatField(BaseMultiModalField):
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
field_factory = self._field_factory()
if not is_list_of(self.slices, slice, check="all"):
assert isinstance(data, torch.Tensor), (
"torch.Tensor is required for multiple slices"
......@@ -664,7 +650,7 @@ class MultiModalSharedField(BaseMultiModalField):
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
field_factory = self._field_factory()
return [field_factory(data)] * self.batch_size
def _reduce_data(
......@@ -899,37 +885,19 @@ class MultiModalFieldConfig:
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
"""
A collection of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
corresponding to a data item in
A dictionary of processed keyword arguments to pass to the model,
corresponding to a single item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
@staticmethod
def dummy(modality: str, nbytes: int = 1):
def dummy(nbytes: int = 1):
"""Convenience class for testing."""
mm_elem = MultiModalFieldElem(
modality=modality,
key="dummy",
data=torch.empty(nbytes, dtype=torch.uint8),
field=MultiModalSharedField(batch_size=1),
)
return MultiModalKwargsItem.from_elems([mm_elem])
@staticmethod
def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.key: elem for elem in elems})
def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
super().__init__(data)
modalities = {elem.modality for elem in self.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}"
self._modality = next(iter(modalities))
@property
def modality(self) -> str:
return self._modality
return MultiModalKwargsItem({"dummy": mm_elem})
def get_data(self) -> dict[str, NestedTensors]:
return {key: elem.data for key, elem in self.items()}
......@@ -945,9 +913,38 @@ _I = TypeVar(
class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
"""
A dictionary of
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
by modality.
A dictionary of processed multi-modal inputs by modality.
For example, given a processor that processes
images into `pixel_values` and `image_grid_thw`,
and audios into `input_audio_features`,
a prompt with 2 images and 1 audio will be processed
into a `MultiModalKwargsItems` with the following structure:
```python
MultiModalKwargsItems(
{
"image": [
# For the first image
MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
# For the second image
MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
],
"audio": [
# For the first audio
MultiModalKwargsItem({"input_audio_features": ...}),
],
}
)
```
Unlike HF processing which returns all items
in a single dictionary with batched keyword arguments,
we split up the items because some of them may already be cached.
Also, items from multiple requests may be batched together to improve throughput,
using the logic defined by the
[`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
for each keyword argument.
"""
@staticmethod
......@@ -967,7 +964,7 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
elems_by_key[key] = elems
keys_by_modality[config.modality].add(key)
items = list[MultiModalKwargsItem]()
items_by_modality = dict[str, list[MultiModalKwargsItem]]()
for modality, keys in keys_by_modality.items():
elems_in_modality = {k: elems_by_key[k] for k in keys}
batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
......@@ -979,15 +976,11 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
)
batch_size = next(iter(batch_sizes.values()))
for item_idx in range(batch_size):
elems = [v[item_idx] for v in elems_in_modality.values()]
items.append(MultiModalKwargsItem.from_elems(elems))
return MultiModalKwargsItems.from_seq(items)
items_by_modality[modality] = [
MultiModalKwargsItem({k: v[i] for k, v in elems_in_modality.items()})
for i in range(batch_size)
]
@staticmethod
def from_seq(items: Sequence[MultiModalKwargsItem]):
items_by_modality = full_groupby(items, key=lambda x: x.modality)
return MultiModalKwargsItems(items_by_modality)
def __getitem__(self, modality: str) -> Sequence[_I]:
......
......@@ -467,7 +467,7 @@ def argsort_mm_positions(
def group_mm_kwargs_by_modality(
mm_kwargs: list[MultiModalKwargsItem],
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
*,
device: torch.types.Device = None,
pin_memory: bool = False,
......@@ -485,9 +485,9 @@ def group_mm_kwargs_by_modality(
"""
from vllm.multimodal.inputs import MultiModalKwargsItems
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
items_lst = list(items)
mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst)
for modality, group in groupby(mm_kwargs, key=lambda x: x[0]):
items_lst = [item for _, item in group]
mm_kwargs_items = MultiModalKwargsItems({modality: items_lst})
mm_kwargs_data = mm_kwargs_items.get_data(
device=device,
pin_memory=pin_memory,
......
......@@ -242,13 +242,11 @@ class MsgpackEncoder:
for modality, itemlist in items.items()
}
def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
return [self._encode_mm_field_elem(elem) for elem in item.values()]
def _encode_mm_item(self, item: MultiModalKwargsItem) -> dict[str, Any]:
return {key: self._encode_mm_field_elem(elem) for key, elem in item.items()}
def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
return {
"modality": elem.modality,
"key": elem.key,
"data": (
None if elem.data is None else self._encode_nested_tensors(elem.data)
),
......@@ -383,9 +381,9 @@ class MsgpackDecoder:
}
)
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems(
[self._decode_mm_field_elem(v) for v in obj]
def _decode_mm_item(self, obj: dict[str, Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem(
{key: self._decode_mm_field_elem(elem) for key, elem in obj.items()}
)
def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
......
......@@ -43,9 +43,9 @@ class EncoderRunner:
def prepare_mm_inputs(
self,
scheduled_encoder_inputs: dict[str, list[int]],
) -> tuple[list[str], list[MultiModalKwargsItem]]:
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
mm_hashes: list[str] = []
mm_kwargs: list[MultiModalKwargsItem] = []
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id]
for mm_input_id in encoder_input_ids:
......@@ -53,7 +53,8 @@ class EncoderRunner:
if mm_feature.data is None:
continue
mm_hashes.append(mm_feature.identifier)
mm_kwargs.append(mm_feature.data)
mm_kwargs.append((mm_feature.modality, mm_feature.data))
return mm_hashes, mm_kwargs
@torch.inference_mode()
......@@ -61,7 +62,7 @@ class EncoderRunner:
self,
model: SupportsMultiModal,
mm_hashes: list[str],
mm_kwargs: list[MultiModalKwargsItem],
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]:
if not mm_hashes:
return []
......
......@@ -1217,11 +1217,11 @@ class GPUModelRunner(
if not scheduler_output or not self.is_multimodal_raw_input_only_model:
return {}
mm_kwargs = list[MultiModalKwargsItem]()
mm_kwargs = list[tuple[str, MultiModalKwargsItem]]()
for req in scheduler_output.scheduled_new_reqs:
for feature in req.mm_features:
if feature.data is not None:
mm_kwargs.append(feature.data)
mm_kwargs.append((feature.modality, feature.data))
# Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
......@@ -2219,7 +2219,7 @@ class GPUModelRunner(
scheduler_output: "SchedulerOutput",
) -> tuple[
list[str],
list[MultiModalKwargsItem],
list[tuple[str, MultiModalKwargsItem]],
list[tuple[str, PlaceholderRange]],
]:
"""Batch multimodal inputs from scheduled encoder inputs.
......@@ -2239,7 +2239,7 @@ class GPUModelRunner(
return [], [], []
mm_hashes = list[str]()
mm_kwargs = list[MultiModalKwargsItem]()
mm_kwargs = list[tuple[str, MultiModalKwargsItem]]()
# Multimodal LoRA reference info to map each multimodal item
# back to its request & position
mm_lora_refs = list[tuple[str, PlaceholderRange]]()
......@@ -2252,7 +2252,7 @@ class GPUModelRunner(
continue
mm_hashes.append(mm_feature.identifier)
mm_kwargs.append(mm_feature.data)
mm_kwargs.append((mm_feature.modality, mm_feature.data))
mm_lora_refs.append((req_id, mm_feature.mm_position))
return mm_hashes, mm_kwargs, mm_lora_refs
......@@ -4475,12 +4475,10 @@ class GPUModelRunner(
# but not read from the cache
assert dummy_mm_item is not None, "Item should not already be cached"
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next(
mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
dummy_mm_items,
[(modality, dummy_mm_item)] * max_items_per_batch,
device=self.device,
pin_memory=self.pin_memory,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment