Unverified Commit 6dd30265 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Rename `group_mm_kwargs_by_modality -> group_and_batch_mm_kwargs` (#36158)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent de00ebea
...@@ -27,7 +27,7 @@ from vllm.distributed import ( ...@@ -27,7 +27,7 @@ from vllm.distributed import (
from vllm.model_executor.models.interfaces import supports_multimodal from vllm.model_executor.models.interfaces import supports_multimodal
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
...@@ -114,7 +114,7 @@ def create_batched_mm_kwargs( ...@@ -114,7 +114,7 @@ def create_batched_mm_kwargs(
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)["mm_kwargs"].require_data() )["mm_kwargs"].require_data()
return group_mm_kwargs_by_modality( return group_and_batch_mm_kwargs(
[ [
(modality, item) (modality, item)
for modality in supported_mm_limits for modality in supported_mm_limits
......
...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any ...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from PIL import Image from PIL import Image
from typing_extensions import deprecated
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
...@@ -207,7 +208,7 @@ def group_and_batch_mm_items( ...@@ -207,7 +208,7 @@ def group_and_batch_mm_items(
assert start_idx == len(items) assert start_idx == len(items)
def group_mm_kwargs_by_modality( def group_and_batch_mm_kwargs(
mm_kwargs: list[tuple[str, MultiModalKwargsItem]], mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
*, *,
device: torch.types.Device = None, device: torch.types.Device = None,
...@@ -246,6 +247,19 @@ def group_mm_kwargs_by_modality( ...@@ -246,6 +247,19 @@ def group_mm_kwargs_by_modality(
yield modality, num_items, mm_kwargs_batch yield modality, num_items, mm_kwargs_batch
@deprecated(
"`group_mm_kwargs_by_modality` has been renamed to `group_and_batch_mm_kwargs`. "
"The old name will be removed in v0.19."
)
def group_mm_kwargs_by_modality(
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
*,
device: torch.types.Device = None,
pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
return group_and_batch_mm_kwargs(mm_kwargs, device=device, pin_memory=pin_memory)
def fetch_audio( def fetch_audio(
audio_url: str, audio_url: str,
audio_io_kwargs: dict[str, Any] | None = None, audio_io_kwargs: dict[str, Any] | None = None,
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalKwargsItem from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
...@@ -53,14 +53,12 @@ class EncoderRunner: ...@@ -53,14 +53,12 @@ class EncoderRunner:
mm_kwargs: list[tuple[str, MultiModalKwargsItem]], mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = [] encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( for modality, num_items, mm_kwargs_batch in group_and_batch_mm_kwargs(
mm_kwargs, device=self.device, pin_memory=False mm_kwargs, device=self.device, pin_memory=False
): ):
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group) batch_outputs = self.model.embed_multimodal(**mm_kwargs_batch)
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items)
curr_group_outputs, expected_num_items=num_items encoder_outputs.extend(batch_outputs)
)
encoder_outputs.extend(curr_group_outputs)
return encoder_outputs return encoder_outputs
def gather_mm_embeddings( def gather_mm_embeddings(
......
...@@ -93,7 +93,7 @@ from vllm.multimodal.inputs import ( ...@@ -93,7 +93,7 @@ from vllm.multimodal.inputs import (
MultiModalKwargsItem, MultiModalKwargsItem,
PlaceholderRange, PlaceholderRange,
) )
from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -1311,12 +1311,12 @@ class GPUModelRunner( ...@@ -1311,12 +1311,12 @@ class GPUModelRunner(
# Input all modalities at once # Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {} mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs(
mm_kwargs, mm_kwargs,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
): ):
mm_kwargs_combined.update(mm_kwargs_group) mm_kwargs_combined.update(mm_kwargs_batch)
return mm_kwargs_combined return mm_kwargs_combined
...@@ -2446,12 +2446,12 @@ class GPUModelRunner( ...@@ -2446,12 +2446,12 @@ class GPUModelRunner(
encoder_outputs: list[torch.Tensor] = [] encoder_outputs: list[torch.Tensor] = []
# Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs # Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs
current_item_idx = 0 current_item_idx = 0
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( for modality, num_items, mm_kwargs_batch in group_and_batch_mm_kwargs(
mm_kwargs, mm_kwargs,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
): ):
curr_group_outputs: MultiModalEmbeddings batch_outputs: MultiModalEmbeddings
# EVS-related change. # EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when # (ekhvedchenia): Temporary hack to limit peak memory usage when
...@@ -2467,14 +2467,14 @@ class GPUModelRunner( ...@@ -2467,14 +2467,14 @@ class GPUModelRunner(
and modality == "video" and modality == "video"
and num_items > 1 and num_items > 1
): ):
curr_group_outputs_lst = list[torch.Tensor]() batch_outputs_lst = list[torch.Tensor]()
for video_idx in range(num_items): for video_idx in range(num_items):
video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx] video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx]
with self.timed_encoder_operation( with self.timed_encoder_operation(
should_time, mm_lora_refs, current_item_idx + video_idx, 1 should_time, mm_lora_refs, current_item_idx + video_idx, 1
): ):
_, _, micro_batch_mm_inputs = next( _, _, micro_batch_mm_inputs = next(
group_mm_kwargs_by_modality( group_and_batch_mm_kwargs(
[video_mm_kwargs_item], [video_mm_kwargs_item],
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
...@@ -2485,12 +2485,12 @@ class GPUModelRunner( ...@@ -2485,12 +2485,12 @@ class GPUModelRunner(
**micro_batch_mm_inputs **micro_batch_mm_inputs
) )
curr_group_outputs_lst.extend(micro_batch_outputs) batch_outputs_lst.extend(micro_batch_outputs)
curr_group_outputs = curr_group_outputs_lst batch_outputs = batch_outputs_lst
else: else:
# Run the encoder. # Run the encoder.
# `curr_group_outputs` is either of the following: # `batch_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size) # 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items. # in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors, # 2. A list or tuple (length: num_items) of tensors,
...@@ -2500,13 +2500,10 @@ class GPUModelRunner( ...@@ -2500,13 +2500,10 @@ class GPUModelRunner(
with self.timed_encoder_operation( with self.timed_encoder_operation(
should_time, mm_lora_refs, current_item_idx, num_items should_time, mm_lora_refs, current_item_idx, num_items
): ):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) batch_outputs = model.embed_multimodal(**mm_kwargs_batch)
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items)
curr_group_outputs, encoder_outputs.extend(batch_outputs)
expected_num_items=num_items,
)
encoder_outputs.extend(curr_group_outputs)
current_item_idx += num_items current_item_idx += num_items
...@@ -4707,8 +4704,8 @@ class GPUModelRunner( ...@@ -4707,8 +4704,8 @@ class GPUModelRunner(
assert dummy_mm_item is not None, "Item should not already be cached" assert dummy_mm_item is not None, "Item should not already be cached"
return next( return next(
mm_kwargs_group mm_kwargs_batch
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs(
[(modality, dummy_mm_item)] * max_items_per_batch, [(modality, dummy_mm_item)] * max_items_per_batch,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment