Unverified commit 671427ef, authored by Cyrus Leung and committed by GitHub
Browse files

[Model] Move `multimodal_cpu_fields` definition to field config (#30181)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 21bb3235
...@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int): ...@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int):
modality=modality, modality=modality,
key=key, key=key,
data=torch.empty((size,), dtype=torch.int8), data=torch.empty((size,), dtype=torch.int8),
field=MultiModalSharedField(1), field=MultiModalSharedField(batch_size=1),
) )
......
...@@ -51,7 +51,7 @@ def _dummy_elem( ...@@ -51,7 +51,7 @@ def _dummy_elem(
modality=modality, modality=modality,
key=key, key=key,
data=data, data=data,
field=MultiModalSharedField(1), field=MultiModalSharedField(batch_size=1),
) )
......
...@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct): ...@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
def test_multimodal_kwargs(): def test_multimodal_kwargs():
e1 = MultiModalFieldElem( e1 = MultiModalFieldElem(
"audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField() "audio",
"a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField(),
) )
e2 = MultiModalFieldElem( e2 = MultiModalFieldElem(
"video", "video",
"v0", "v0",
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)], [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), MultiModalFlatField(
slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
dim=0,
),
) )
e3 = MultiModalFieldElem( e3 = MultiModalFieldElem(
"image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4) "image",
"i0",
torch.zeros(1000, dtype=torch.int32),
MultiModalSharedField(batch_size=4),
) )
e4 = MultiModalFieldElem( e4 = MultiModalFieldElem(
"image", "image",
"i1", "i1",
torch.zeros(1000, dtype=torch.int32), torch.zeros(1000, dtype=torch.int32),
MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2), MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
) )
audio = MultiModalKwargsItem.from_elems([e1]) audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2]) video = MultiModalKwargsItem.from_elems([e2])
...@@ -138,8 +147,8 @@ def test_multimodal_kwargs(): ...@@ -138,8 +147,8 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14306, +-20 for minor changes # expected total encoding length, should be 14395, +-20 for minor changes
assert 14275 <= total_len <= 14325 assert 14375 <= total_len <= 14425
decoded = decoder.decode(encoded).mm[0] decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems) assert isinstance(decoded, MultiModalKwargsItems)
......
...@@ -787,10 +787,10 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -787,10 +787,10 @@ class Glm4vVisionTransformer(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
grid_thw: list[list[int]], grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor: ) -> torch.Tensor:
# Convert grid_thw to tensor (always expecting list format now) if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long) grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
# patchify # patchify
x = x.to(device=self.device, dtype=self.dtype) x = x.to(device=self.device, dtype=self.dtype)
...@@ -805,7 +805,8 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -805,7 +805,8 @@ class Glm4vVisionTransformer(nn.Module):
cu_seqlens = torch.repeat_interleave( cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32) ).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
...@@ -1548,7 +1549,6 @@ class Glm4vForConditionalGeneration( ...@@ -1548,7 +1549,6 @@ class Glm4vForConditionalGeneration(
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"] grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2 assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype) image_embeds = image_input["image_embeds"].type(self.visual.dtype)
...@@ -1559,12 +1559,10 @@ class Glm4vForConditionalGeneration( ...@@ -1559,12 +1559,10 @@ class Glm4vForConditionalGeneration(
self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
) )
else: else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist()) image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size
sizes = ( sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
// (merge_size * merge_size)
).tolist()
return image_embeds.split(sizes) return image_embeds.split(sizes)
def _process_video_input( def _process_video_input(
...@@ -1572,7 +1570,6 @@ class Glm4vForConditionalGeneration( ...@@ -1572,7 +1570,6 @@ class Glm4vForConditionalGeneration(
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"] grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2 assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if video_input["type"] == "video_embeds": if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype) video_embeds = video_input["video_embeds"].type(self.visual.dtype)
...@@ -1588,15 +1585,11 @@ class Glm4vForConditionalGeneration( ...@@ -1588,15 +1585,11 @@ class Glm4vForConditionalGeneration(
rope_type="rope_3d", rope_type="rope_3d",
) )
else: else:
video_embeds = self.visual( video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
pixel_values_videos, grid_thw=grid_thw.tolist()
)
# Split concatenated embeddings for each video item. # Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size
sizes = ( sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
// (merge_size * merge_size)
).tolist()
return video_embeds.split(sizes) return video_embeds.split(sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
......
...@@ -563,7 +563,7 @@ def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): ...@@ -563,7 +563,7 @@ def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
return dict( return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"), image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
) )
...@@ -786,8 +786,6 @@ class HunYuanVLForConditionalGeneration( ...@@ -786,8 +786,6 @@ class HunYuanVLForConditionalGeneration(
SupportsQuant, SupportsQuant,
SupportsXDRoPE, SupportsXDRoPE,
): ):
multimodal_cpu_fields = {"image_grid_thw"}
# To ensure correct weight loading and mapping. # To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
......
...@@ -84,9 +84,9 @@ class SupportsMultiModal(Protocol): ...@@ -84,9 +84,9 @@ class SupportsMultiModal(Protocol):
`vllm.multimodal.utils.group_mm_kwargs_by_modality` to use. `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
""" """
multimodal_cpu_fields: ClassVar[Set[str]] = frozenset() multimodal_cpu_fields: ClassVar[Set[str] | None] = None
""" """
A set indicating CPU-only multimodal fields. [DEPRECATED] A set indicating CPU-only multimodal fields.
""" """
_processor_factory: ClassVar[_ProcessorFactories] _processor_factory: ClassVar[_ProcessorFactories]
...@@ -279,6 +279,15 @@ def supports_multimodal( ...@@ -279,6 +279,15 @@ def supports_multimodal(
"please remove the override from your model." "please remove the override from your model."
) )
multimodal_cpu_fields = getattr(model, "multimodal_cpu_fields", None)
if multimodal_cpu_fields is not None:
raise ValueError(
"`multimodal_cpu_fields` is no longer effective, "
"please set `keep_on_cpu=True` in `MultiModalFieldConfig` "
"(refer to https://github.com/vllm-project/vllm/pull/30181), "
"and then remove the override from your model."
)
return res return res
......
...@@ -201,8 +201,6 @@ class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder): ...@@ -201,8 +201,6 @@ class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder):
dummy_inputs=OpenCUADummyInputsBuilder, dummy_inputs=OpenCUADummyInputsBuilder,
) )
class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
multimodal_cpu_fields = {"image_grid_thw"}
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
......
...@@ -1039,8 +1039,6 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1039,8 +1039,6 @@ class Qwen2_5_VLForConditionalGeneration(
SupportsMultiModalPruning, SupportsMultiModalPruning,
SupportsMRoPE, SupportsMRoPE,
): ):
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
......
...@@ -811,14 +811,14 @@ def _create_qwen2vl_field_factory( ...@@ -811,14 +811,14 @@ def _create_qwen2vl_field_factory(
image_embeds=MultiModalFieldConfig.flat_from_sizes( image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_embed_grid_sizes "image", image_embed_grid_sizes
), ),
image_grid_thw=MultiModalFieldConfig.batched("image"), image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes "video", video_grid_sizes
), ),
video_embeds=MultiModalFieldConfig.flat_from_sizes( video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_embed_grid_sizes "video", video_embed_grid_sizes
), ),
video_grid_thw=MultiModalFieldConfig.batched("video"), video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
) )
return _qwen2vl_field_config return _qwen2vl_field_config
...@@ -1131,8 +1131,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]) ...@@ -1131,8 +1131,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo])
class Qwen2VLForConditionalGeneration( class Qwen2VLForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
): ):
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
# To ensure correct weight loading and mapping. # To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
...@@ -1393,9 +1391,11 @@ class Qwen2VLForConditionalGeneration( ...@@ -1393,9 +1391,11 @@ class Qwen2VLForConditionalGeneration(
else: else:
pixel_values_videos = video_input["pixel_values_videos"] pixel_values_videos = video_input["pixel_values_videos"]
if self.use_data_parallel: if self.use_data_parallel:
grid_thw_list = grid_thw.tolist()
return run_dp_sharded_mrope_vision_model( return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" self.visual,
pixel_values_videos,
grid_thw.tolist(),
rope_type="rope_3d",
) )
else: else:
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
......
...@@ -984,14 +984,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ...@@ -984,14 +984,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
image_embeds=MultiModalFieldConfig.flat_from_sizes( image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes "image", image_grid_sizes
), ),
image_grid_thw=MultiModalFieldConfig.batched("image"), image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes "video", video_grid_sizes
), ),
video_embeds=MultiModalFieldConfig.flat_from_sizes( video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes "video", video_grid_sizes
), ),
video_grid_thw=MultiModalFieldConfig.batched("video"), video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -1190,8 +1190,6 @@ class Qwen3VLForConditionalGeneration( ...@@ -1190,8 +1190,6 @@ class Qwen3VLForConditionalGeneration(
SupportsMRoPE, SupportsMRoPE,
SupportsEagle3, SupportsEagle3,
): ):
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence, Set from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from itertools import accumulate from itertools import accumulate
...@@ -223,6 +223,23 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: ...@@ -223,6 +223,23 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
return a == b return a == b
def _nested_tensors_h2d(
tensors: NestedTensors,
device: torch.types.Device,
) -> NestedTensors:
if device is None:
return tensors
return json_map_leaves(
(
lambda x: x.to(device=device, non_blocking=True)
if isinstance(x, torch.Tensor)
else x
),
tensors,
)
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors] BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
""" """
A dictionary containing nested tensors which have been batched via A dictionary containing nested tensors which have been batched via
...@@ -334,7 +351,7 @@ class MultiModalFieldElem: ...@@ -334,7 +351,7 @@ class MultiModalFieldElem:
) # noqa: E721 ) # noqa: E721
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC): class BaseMultiModalField(ABC):
""" """
Defines how to interpret tensor data belonging to a keyword argument in Defines how to interpret tensor data belonging to a keyword argument in
...@@ -342,6 +359,12 @@ class BaseMultiModalField(ABC): ...@@ -342,6 +359,12 @@ class BaseMultiModalField(ABC):
multi-modal items, and vice versa. multi-modal items, and vice versa.
""" """
keep_on_cpu: bool = False
"""
If `True`, then this field is excluded from being moved to the accelerator
when `MultiModalKwargsItems.get_data()` is called to batch the data.
"""
def _field_factory(self, *, modality: str, key: str): def _field_factory(self, *, modality: str, key: str):
f = partial( f = partial(
MultiModalFieldElem, MultiModalFieldElem,
...@@ -386,6 +409,7 @@ class BaseMultiModalField(ABC): ...@@ -386,6 +409,7 @@ class BaseMultiModalField(ABC):
self, self,
elems: list[MultiModalFieldElem], elems: list[MultiModalFieldElem],
*, *,
device: torch.types.Device = None,
pin_memory: bool = False, pin_memory: bool = False,
) -> NestedTensors: ) -> NestedTensors:
""" """
...@@ -399,11 +423,17 @@ class BaseMultiModalField(ABC): ...@@ -399,11 +423,17 @@ class BaseMultiModalField(ABC):
if len(set(field_types)) > 1: if len(set(field_types)) > 1:
raise ValueError(f"Cannot merge different {field_types=}") raise ValueError(f"Cannot merge different {field_types=}")
if device is not None and self.keep_on_cpu:
device = "cpu"
if pin_memory and self.keep_on_cpu:
pin_memory = False
batch = [elem.data for elem in elems] batch = [elem.data for elem in elems]
return self._reduce_data(batch, pin_memory=pin_memory) out = self._reduce_data(batch, pin_memory=pin_memory)
return _nested_tensors_h2d(out, device=device)
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField): class MultiModalBatchedField(BaseMultiModalField):
""" """
Info: Info:
...@@ -445,7 +475,7 @@ class MultiModalBatchedField(BaseMultiModalField): ...@@ -445,7 +475,7 @@ class MultiModalBatchedField(BaseMultiModalField):
return batch return batch
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField): class MultiModalFlatField(BaseMultiModalField):
""" """
Info: Info:
...@@ -505,7 +535,7 @@ class MultiModalFlatField(BaseMultiModalField): ...@@ -505,7 +535,7 @@ class MultiModalFlatField(BaseMultiModalField):
return [e for elem in batch for e in elem] return [e for elem in batch for e in elem]
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField): class MultiModalSharedField(BaseMultiModalField):
""" """
Info: Info:
...@@ -532,9 +562,10 @@ class MultiModalSharedField(BaseMultiModalField): ...@@ -532,9 +562,10 @@ class MultiModalSharedField(BaseMultiModalField):
return batch[0] return batch[0]
@dataclass(frozen=True)
class MultiModalFieldConfig: class MultiModalFieldConfig:
@staticmethod @staticmethod
def batched(modality: str): def batched(modality: str, *, keep_on_cpu: bool = False):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
indexing into the first dimension of the underlying data. indexing into the first dimension of the underlying data.
...@@ -542,6 +573,7 @@ class MultiModalFieldConfig: ...@@ -542,6 +573,7 @@ class MultiModalFieldConfig:
Args: Args:
modality: The modality of the multi-modal item that uses this modality: The modality of the multi-modal item that uses this
keyword argument. keyword argument.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example: Example:
...@@ -558,7 +590,7 @@ class MultiModalFieldConfig: ...@@ -558,7 +590,7 @@ class MultiModalFieldConfig:
``` ```
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalBatchedField(), field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
modality=modality, modality=modality,
) )
...@@ -567,6 +599,8 @@ class MultiModalFieldConfig: ...@@ -567,6 +599,8 @@ class MultiModalFieldConfig:
modality: str, modality: str,
slices: Sequence[slice] | Sequence[Sequence[slice]], slices: Sequence[slice] | Sequence[Sequence[slice]],
dim: int = 0, dim: int = 0,
*,
keep_on_cpu: bool = False,
): ):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
...@@ -579,6 +613,7 @@ class MultiModalFieldConfig: ...@@ -579,6 +613,7 @@ class MultiModalFieldConfig:
slices (dim>0) that is used to extract the data corresponding slices (dim>0) that is used to extract the data corresponding
to it. to it.
dim: The dimension to extract data, default to 0. dim: The dimension to extract data, default to 0.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example: Example:
...@@ -613,12 +648,22 @@ class MultiModalFieldConfig: ...@@ -613,12 +648,22 @@ class MultiModalFieldConfig:
``` ```
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalFlatField(slices=slices, dim=dim), field=MultiModalFlatField(
slices=slices,
dim=dim,
keep_on_cpu=keep_on_cpu,
),
modality=modality, modality=modality,
) )
@staticmethod @staticmethod
def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0): def flat_from_sizes(
modality: str,
size_per_item: "torch.Tensor",
dim: int = 0,
*,
keep_on_cpu: bool = False,
):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data. slicing along the first dimension of the underlying data.
...@@ -629,6 +674,7 @@ class MultiModalFieldConfig: ...@@ -629,6 +674,7 @@ class MultiModalFieldConfig:
size_per_item: For each multi-modal item, the size of the slice size_per_item: For each multi-modal item, the size of the slice
that is used to extract the data corresponding to it. that is used to extract the data corresponding to it.
dim: The dimension to slice, default to 0. dim: The dimension to slice, default to 0.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example: Example:
...@@ -676,10 +722,20 @@ class MultiModalFieldConfig: ...@@ -676,10 +722,20 @@ class MultiModalFieldConfig:
for i in range(len(size_per_item)) for i in range(len(size_per_item))
] ]
return MultiModalFieldConfig.flat(modality, slices, dim=dim) return MultiModalFieldConfig.flat(
modality,
slices,
dim=dim,
keep_on_cpu=keep_on_cpu,
)
@staticmethod @staticmethod
def shared(modality: str, batch_size: int): def shared(
modality: str,
batch_size: int,
*,
keep_on_cpu: bool = False,
):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
taking the entirety of the underlying data. taking the entirety of the underlying data.
...@@ -690,6 +746,7 @@ class MultiModalFieldConfig: ...@@ -690,6 +746,7 @@ class MultiModalFieldConfig:
modality: The modality of the multi-modal item that uses this modality: The modality of the multi-modal item that uses this
keyword argument. keyword argument.
batch_size: The number of multi-modal items which share this data. batch_size: The number of multi-modal items which share this data.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example: Example:
...@@ -708,18 +765,15 @@ class MultiModalFieldConfig: ...@@ -708,18 +765,15 @@ class MultiModalFieldConfig:
``` ```
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalSharedField(batch_size), field=MultiModalSharedField(
batch_size=batch_size,
keep_on_cpu=keep_on_cpu,
),
modality=modality, modality=modality,
) )
def __init__(self, field: BaseMultiModalField, modality: str) -> None: field: BaseMultiModalField
super().__init__() modality: str
self.field = field
self.modality = modality
def __repr__(self) -> str:
return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"
def build_elems( def build_elems(
self, self,
...@@ -744,7 +798,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): ...@@ -744,7 +798,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
modality=modality, modality=modality,
key="dummy", key="dummy",
data=torch.empty(nbytes, dtype=torch.uint8), data=torch.empty(nbytes, dtype=torch.uint8),
field=MultiModalSharedField(1), field=MultiModalSharedField(batch_size=1),
) )
return MultiModalKwargsItem.from_elems([mm_elem]) return MultiModalKwargsItem.from_elems([mm_elem])
...@@ -844,7 +898,6 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): ...@@ -844,7 +898,6 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
*, *,
device: torch.types.Device = None, device: torch.types.Device = None,
pin_memory: bool = False, pin_memory: bool = False,
cpu_fields: Set[str] = frozenset(),
) -> BatchedTensorInputs: ) -> BatchedTensorInputs:
"""Construct a dictionary of keyword arguments to pass to the model.""" """Construct a dictionary of keyword arguments to pass to the model."""
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
...@@ -859,21 +912,14 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): ...@@ -859,21 +912,14 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
elems_by_key[key].append(elem) elems_by_key[key].append(elem)
data = { data = {
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) key: elems[0].field.reduce_data(
elems,
device=device,
pin_memory=pin_memory,
)
for key, elems in elems_by_key.items() for key, elems in elems_by_key.items()
} }
if device is not None:
for k in data.keys() - cpu_fields:
data[k] = json_map_leaves(
(
lambda x: x.to(device=device, non_blocking=True)
if isinstance(x, torch.Tensor)
else x
),
data[k],
)
return data return data
......
...@@ -413,7 +413,7 @@ def group_mm_kwargs_by_modality( ...@@ -413,7 +413,7 @@ def group_mm_kwargs_by_modality(
device: torch.types.Device = None, device: torch.types.Device = None,
pin_memory: bool = False, pin_memory: bool = False,
merge_by_field_config: bool | None = None, merge_by_field_config: bool | None = None,
multimodal_cpu_fields: Set[str] = frozenset(), multimodal_cpu_fields: Set[str] | None = None,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]: ) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance. modality together into the same `MultiModalKwargs` instance.
...@@ -431,6 +431,11 @@ def group_mm_kwargs_by_modality( ...@@ -431,6 +431,11 @@ def group_mm_kwargs_by_modality(
"The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` " "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
"is deprecated and will be removed in v0.13." "is deprecated and will be removed in v0.13."
) )
if multimodal_cpu_fields is not None:
logger.warning_once(
"The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
"is deprecated and will be removed in v0.13."
)
from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalKwargsItems
...@@ -440,7 +445,6 @@ def group_mm_kwargs_by_modality( ...@@ -440,7 +445,6 @@ def group_mm_kwargs_by_modality(
mm_kwargs_data = mm_kwargs_items.get_data( mm_kwargs_data = mm_kwargs_items.get_data(
device=device, device=device,
pin_memory=pin_memory, pin_memory=pin_memory,
cpu_fields=multimodal_cpu_fields,
) )
yield modality, len(items_lst), mm_kwargs_data yield modality, len(items_lst), mm_kwargs_data
......
...@@ -269,10 +269,11 @@ class MsgpackEncoder: ...@@ -269,10 +269,11 @@ class MsgpackEncoder:
name = MMF_CLASS_TO_FACTORY.get(field.__class__) name = MMF_CLASS_TO_FACTORY.get(field.__class__)
if not name: if not name:
raise TypeError(f"Unsupported field type: {field.__class__}") raise TypeError(f"Unsupported field type: {field.__class__}")
# We just need to copy all of the field values in order # We just need to copy all of the field values in order
# which will be then used to reconstruct the field. # which will be then used to reconstruct the field.
field_values = (getattr(field, f.name) for f in dataclasses.fields(field)) factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}
return name, *field_values return name, factory_kw
class MsgpackDecoder: class MsgpackDecoder:
...@@ -392,15 +393,15 @@ class MsgpackDecoder: ...@@ -392,15 +393,15 @@ class MsgpackDecoder:
obj["data"] = self._decode_nested_tensors(obj["data"]) obj["data"] = self._decode_nested_tensors(obj["data"])
# Reconstruct the field processor using MultiModalFieldConfig # Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = obj["field"] factory_meth_name, factory_kw = obj["field"]
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
# Special case: decode the union "slices" field of # Special case: decode the union "slices" field of
# MultiModalFlatField # MultiModalFlatField
if factory_meth_name == "flat": if factory_meth_name == "flat":
field_args[0] = self._decode_nested_slices(field_args[0]) factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"])
obj["field"] = factory_meth(None, *field_args).field obj["field"] = factory_meth("", **factory_kw).field
return MultiModalFieldElem(**obj) return MultiModalFieldElem(**obj)
def _decode_nested_tensors(self, obj: Any) -> NestedTensors: def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
......
...@@ -1097,7 +1097,6 @@ class GPUModelRunner( ...@@ -1097,7 +1097,6 @@ class GPUModelRunner(
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
merge_by_field_config=model.merge_by_field_config, merge_by_field_config=model.merge_by_field_config,
multimodal_cpu_fields=model.multimodal_cpu_fields,
): ):
mm_kwargs_combined.update(mm_kwargs_group) mm_kwargs_combined.update(mm_kwargs_group)
...@@ -2109,7 +2108,6 @@ class GPUModelRunner( ...@@ -2109,7 +2108,6 @@ class GPUModelRunner(
mm_kwargs, mm_kwargs,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
): ):
curr_group_outputs: list[torch.Tensor] = [] curr_group_outputs: list[torch.Tensor] = []
...@@ -2135,7 +2133,6 @@ class GPUModelRunner( ...@@ -2135,7 +2133,6 @@ class GPUModelRunner(
[video_mm_kwargs_item], [video_mm_kwargs_item],
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
) )
) )
...@@ -3887,14 +3884,12 @@ class GPUModelRunner( ...@@ -3887,14 +3884,12 @@ class GPUModelRunner(
dummy_mm_item = dummy_mm_data[modality][0] dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch dummy_mm_items = [dummy_mm_item] * max_items_per_batch
model = cast(SupportsMultiModal, self.model)
return next( return next(
mm_kwargs_group mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
dummy_mm_items, dummy_mm_items,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
) )
) )
......
...@@ -969,7 +969,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -969,7 +969,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_kwargs, mm_kwargs,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
): ):
# Run the encoder. # Run the encoder.
# `curr_group_outputs` is either of the following: # `curr_group_outputs` is either of the following:
...@@ -2050,14 +2049,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2050,14 +2049,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_item = dummy_mm_data[modality][0] dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch dummy_mm_items = [dummy_mm_item] * max_items_per_batch
model = cast(SupportsMultiModal, self.model)
return next( return next(
grouped_mm_kwargs grouped_mm_kwargs
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
dummy_mm_items, dummy_mm_items,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment