Unverified commit 6dd30265, authored by Cyrus Leung and committed by GitHub
Browse files

[Misc] Rename `group_mm_kwargs_by_modality -> group_and_batch_mm_kwargs` (#36158)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent de00ebea
......@@ -27,7 +27,7 @@ from vllm.distributed import (
from vllm.model_executor.models.interfaces import supports_multimodal
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.platforms import current_platform
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of
......@@ -114,7 +114,7 @@ def create_batched_mm_kwargs(
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)["mm_kwargs"].require_data()
return group_mm_kwargs_by_modality(
return group_and_batch_mm_kwargs(
[
(modality, item)
for modality in supported_mm_limits
......
......@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any
import numpy as np
import numpy.typing as npt
from PIL import Image
from typing_extensions import deprecated
from vllm.utils.import_utils import LazyLoader
......@@ -207,7 +208,7 @@ def group_and_batch_mm_items(
assert start_idx == len(items)
def group_mm_kwargs_by_modality(
def group_and_batch_mm_kwargs(
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
*,
device: torch.types.Device = None,
......@@ -246,6 +247,19 @@ def group_mm_kwargs_by_modality(
yield modality, num_items, mm_kwargs_batch
@deprecated(
    "`group_mm_kwargs_by_modality` has been renamed to `group_and_batch_mm_kwargs`. "
    "The old name will be removed in v0.19."
)
def group_mm_kwargs_by_modality(
    mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """Deprecated alias kept for callers still using the old name.

    Forwards `mm_kwargs`, `device` and `pin_memory` unchanged to
    `group_and_batch_mm_kwargs` and returns its result as-is.
    """
    # Pure pass-through: no argument is transformed on the way in,
    # and the delegate's return value is handed back untouched.
    batches = group_and_batch_mm_kwargs(
        mm_kwargs,
        device=device,
        pin_memory=pin_memory,
    )
    return batches
def fetch_audio(
audio_url: str,
audio_io_kwargs: dict[str, Any] | None = None,
......
......@@ -5,7 +5,7 @@ import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
......@@ -53,14 +53,12 @@ class EncoderRunner:
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
for modality, num_items, mm_kwargs_batch in group_and_batch_mm_kwargs(
mm_kwargs, device=self.device, pin_memory=False
):
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs, expected_num_items=num_items
)
encoder_outputs.extend(curr_group_outputs)
batch_outputs = self.model.embed_multimodal(**mm_kwargs_batch)
sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items)
encoder_outputs.extend(batch_outputs)
return encoder_outputs
def gather_mm_embeddings(
......
......@@ -93,7 +93,7 @@ from vllm.multimodal.inputs import (
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
......@@ -1311,12 +1311,12 @@ class GPUModelRunner(
# Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
mm_kwargs_combined.update(mm_kwargs_batch)
return mm_kwargs_combined
......@@ -2446,12 +2446,12 @@ class GPUModelRunner(
encoder_outputs: list[torch.Tensor] = []
# Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs
current_item_idx = 0
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
for modality, num_items, mm_kwargs_batch in group_and_batch_mm_kwargs(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
curr_group_outputs: MultiModalEmbeddings
batch_outputs: MultiModalEmbeddings
# EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when
......@@ -2467,14 +2467,14 @@ class GPUModelRunner(
and modality == "video"
and num_items > 1
):
curr_group_outputs_lst = list[torch.Tensor]()
batch_outputs_lst = list[torch.Tensor]()
for video_idx in range(num_items):
video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx]
with self.timed_encoder_operation(
should_time, mm_lora_refs, current_item_idx + video_idx, 1
):
_, _, micro_batch_mm_inputs = next(
group_mm_kwargs_by_modality(
group_and_batch_mm_kwargs(
[video_mm_kwargs_item],
device=self.device,
pin_memory=self.pin_memory,
......@@ -2485,12 +2485,12 @@ class GPUModelRunner(
**micro_batch_mm_inputs
)
curr_group_outputs_lst.extend(micro_batch_outputs)
batch_outputs_lst.extend(micro_batch_outputs)
curr_group_outputs = curr_group_outputs_lst
batch_outputs = batch_outputs_lst
else:
# Run the encoder.
# `curr_group_outputs` is either of the following:
# `batch_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors,
......@@ -2500,13 +2500,10 @@ class GPUModelRunner(
with self.timed_encoder_operation(
should_time, mm_lora_refs, current_item_idx, num_items
):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
batch_outputs = model.embed_multimodal(**mm_kwargs_batch)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=num_items,
)
encoder_outputs.extend(curr_group_outputs)
sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items)
encoder_outputs.extend(batch_outputs)
current_item_idx += num_items
......@@ -4707,8 +4704,8 @@ class GPUModelRunner(
assert dummy_mm_item is not None, "Item should not already be cached"
return next(
mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs_batch
for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs(
[(modality, dummy_mm_item)] * max_items_per_batch,
device=self.device,
pin_memory=self.pin_memory,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment