Unverified Commit 19b927e5 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Use individual MM items in P0/P1 cache and model runner (#22570)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 20d65aa7
...@@ -113,6 +113,9 @@ class MsgpackEncoder: ...@@ -113,6 +113,9 @@ class MsgpackEncoder:
int(v) if v is not None else None int(v) if v is not None else None
for v in (obj.start, obj.stop, obj.step)) for v in (obj.start, obj.stop, obj.step))
if isinstance(obj, MultiModalKwargsItem):
return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargs): if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj mm: MultiModalKwargs = obj
if not mm.modalities: if not mm.modalities:
...@@ -120,17 +123,12 @@ class MsgpackEncoder: ...@@ -120,17 +123,12 @@ class MsgpackEncoder:
return dict(mm) return dict(mm)
# ignore the main dict, it will be re-indexed. # ignore the main dict, it will be re-indexed.
# Encode a list of MultiModalKwargsItems as plain dicts
# + special handling for .field.
# Any tensors *not* indexed by modality will be ignored. # Any tensors *not* indexed by modality will be ignored.
return [[{ return [
"modality": elem.modality, self._encode_mm_item(item)
"key": elem.key,
"data": self._encode_nested_tensors(elem.data),
"field": self._encode_mm_field(elem.field),
} for elem in item.values()]
for itemlist in mm._items_by_modality.values() for itemlist in mm._items_by_modality.values()
for item in itemlist] for item in itemlist
]
if isinstance(obj, UtilityResult): if isinstance(obj, UtilityResult):
result = obj.result result = obj.result
...@@ -192,6 +190,23 @@ class MsgpackEncoder: ...@@ -192,6 +190,23 @@ class MsgpackEncoder:
dtype = str(obj.dtype).removeprefix("torch.") dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data return dtype, obj.shape, data
def _encode_mm_item(self,
item: MultiModalKwargsItem) -> list[dict[str, Any]]:
return [self._encode_mm_field_elem(elem) for elem in item.values()]
def _encode_mm_field_elem(self,
elem: MultiModalFieldElem) -> dict[str, Any]:
return {
"modality":
elem.modality,
"key":
elem.key,
"data": (None if elem.data is None else
self._encode_nested_tensors(elem.data)),
"field":
self._encode_mm_field(elem.field),
}
def _encode_nested_tensors(self, nt: NestedTensors) -> Any: def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor): if isinstance(nt, torch.Tensor):
return self._encode_tensor(nt) return self._encode_tensor(nt)
...@@ -250,6 +265,8 @@ class MsgpackDecoder: ...@@ -250,6 +265,8 @@ class MsgpackDecoder:
return self._decode_tensor(obj) return self._decode_tensor(obj)
if t is slice: if t is slice:
return slice(*obj) return slice(*obj)
if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargs): if issubclass(t, MultiModalKwargs):
if isinstance(obj, list): if isinstance(obj, list):
return MultiModalKwargs.from_items( return MultiModalKwargs.from_items(
...@@ -311,15 +328,18 @@ class MsgpackDecoder: ...@@ -311,15 +328,18 @@ class MsgpackDecoder:
# Convert back to proper shape & type # Convert back to proper shape & type
return arr.view(torch_dtype).view(shape) return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
return [self._decode_mm_item(v) for v in obj] return [self._decode_mm_item(v) for v in obj]
def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem: def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems( return MultiModalKwargsItem.from_elems(
[self._decode_mm_field_elem(v) for v in obj]) [self._decode_mm_field_elem(v) for v in obj])
def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem: def _decode_mm_field_elem(self, obj: dict[str,
Any]) -> MultiModalFieldElem:
if obj["data"] is not None:
obj["data"] = self._decode_nested_tensors(obj["data"]) obj["data"] = self._decode_nested_tensors(obj["data"])
# Reconstruct the field processor using MultiModalFieldConfig # Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = obj["field"] factory_meth_name, *field_args = obj["field"]
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
......
...@@ -7,9 +7,11 @@ from typing import Optional, cast ...@@ -7,9 +7,11 @@ from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values from vllm.utils import swap_dict_values
...@@ -29,7 +31,7 @@ class CachedRequestState: ...@@ -29,7 +31,7 @@ class CachedRequestState:
req_id: str req_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs] mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
...@@ -51,6 +53,13 @@ class CachedRequestState: ...@@ -51,6 +53,13 @@ class CachedRequestState:
def num_tokens(self) -> int: def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids) return self.num_prompt_tokens + len(self.output_token_ids)
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargs]:
return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs]
def get_token_id(self, idx: int) -> int: def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens: if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx] return self.prompt_token_ids[idx]
......
...@@ -40,9 +40,9 @@ from vllm.model_executor.models.interfaces import (is_mixture_of_experts, ...@@ -40,9 +40,9 @@ from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
from vllm.model_executor.models.interfaces_base import ( from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model) VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange) PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
...@@ -478,7 +478,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -478,7 +478,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState( self.requests[req_id] = CachedRequestState(
req_id=req_id, req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs, mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions, mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params, pooling_params=pooling_params,
...@@ -496,18 +496,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -496,18 +496,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts = [] second_per_grid_ts = []
audio_feature_lengths = [] audio_feature_lengths = []
use_audio_in_video = False use_audio_in_video = False
for mm_input in self.requests[req_id].mm_inputs: for item in self.requests[req_id].mm_kwargs:
mm_input = item.require_data()
if mm_input.get("image_grid_thw") is not None: if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend( image_grid_thw.append(
mm_input["image_grid_thw"].tolist()) mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None: if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend( video_grid_thw.append(
mm_input["video_grid_thw"].tolist()) mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None: if mm_input.get("second_per_grid_ts") is not None:
second_per_grid_ts.extend( second_per_grid_ts.append(
mm_input["second_per_grid_ts"]) mm_input["second_per_grid_ts"])
if mm_input.get("audio_feature_lengths") is not None: if mm_input.get("audio_feature_lengths") is not None:
audio_feature_lengths.extend( audio_feature_lengths.append(
mm_input["audio_feature_lengths"]) mm_input["audio_feature_lengths"])
if mm_input.get("use_audio_in_video") is True: if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True use_audio_in_video = True
...@@ -624,14 +625,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -624,14 +625,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> BatchedTensorInputs: ) -> BatchedTensorInputs:
if self.is_multimodal_raw_input_supported: # noqa: SIM102 if self.is_multimodal_raw_input_supported: # noqa: SIM102
if scheduler_output: if scheduler_output:
multi_modal_kwargs_list = list[MultiModalKwargs]() mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs: for req in scheduler_output.scheduled_new_reqs:
req_mm_inputs = req.mm_inputs req_mm_kwargs = req.mm_kwargs
if not isinstance(req_mm_inputs, list): if not isinstance(req_mm_kwargs, list):
req_mm_inputs = list(req_mm_inputs) req_mm_kwargs = list(req_mm_kwargs)
multi_modal_kwargs_list.extend(req_mm_inputs) mm_kwargs.extend(req_mm_kwargs)
# Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
return MultiModalKwargs.batch(multi_modal_kwargs_list) return mm_kwargs_combined
return {} return {}
...@@ -1146,13 +1156,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1146,13 +1156,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]() mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]() req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id]) mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append( req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id])) (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
...@@ -1163,17 +1173,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1163,17 +1173,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching # in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the # multimodal inputs. The proper solution should be reordering the
# encoder outputs. # encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = [] encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list: for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
batched_mm_inputs = MultiModalKwargs.batch( mm_kwargs,
grouped_mm_inputs, pin_memory=self.pin_memory)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
device=self.device, device=self.device,
) pin_memory=self.pin_memory,
):
# Run the encoder. # Run the encoder.
# `curr_group_outputs` is either of the following: # `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size) # 1. A tensor of shape (num_items, feature_size, hidden_size)
...@@ -1182,11 +1187,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1182,11 +1187,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# (feature_size, hidden_size) in case the feature size is dynamic # (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items. # depending on the input multimodal items.
curr_group_outputs = self.model.get_multimodal_embeddings( curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs) **mm_kwargs_group)
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(
curr_group_outputs, curr_group_outputs,
expected_num_items=len(grouped_mm_inputs), expected_num_items=num_items,
) )
for output in curr_group_outputs: for output in curr_group_outputs:
...@@ -1553,17 +1558,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1553,17 +1558,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens] inputs_embeds = self.inputs_embeds[:num_input_tokens]
model_mm_kwargs = self._extract_mm_kwargs(scheduler_output) model_kwargs = {
model_kwargs = self._init_model_kwargs(num_scheduled_tokens) **self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}
else: else:
# For text-only models, we use token ids as input. # For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the # While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since # multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph. # then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens] input_ids = self.input_ids[:num_input_tokens]
model_kwargs = self._init_model_kwargs(num_input_tokens)
inputs_embeds = None inputs_embeds = None
model_mm_kwargs = {} model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens] positions = self.mrope_positions[:, :num_input_tokens]
else: else:
...@@ -1596,10 +1602,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1596,10 +1602,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_mm_kwargs,
device=self.device,
),
**model_kwargs, **model_kwargs,
) )
...@@ -2196,14 +2198,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2196,14 +2198,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Result in the maximum GPU consumption of the model # Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * return next(mm_kwargs_group
max_items_per_batch) for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
return MultiModalKwargs.as_kwargs( [dummy_mm_item] * max_items_per_batch,
batched_dummy_mm_inputs,
device=self.device, device=self.device,
) pin_memory=self.pin_memory,
))
@torch.inference_mode() @torch.inference_mode()
def _dummy_run( def _dummy_run(
...@@ -2269,15 +2270,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2269,15 +2270,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens): num_scheduled_tokens):
model_kwargs = self._init_model_kwargs(num_tokens)
if self.supports_mm_inputs: if self.supports_mm_inputs:
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
model_mm_kwargs = self._dummy_mm_kwargs(num_reqs) model_kwargs = {
**self._init_model_kwargs(num_tokens),
**self._dummy_mm_kwargs(num_reqs),
}
else: else:
input_ids = self.input_ids[:num_tokens] input_ids = self.input_ids[:num_tokens]
inputs_embeds = None inputs_embeds = None
model_mm_kwargs = {} model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens] positions = self.mrope_positions[:, :num_tokens]
...@@ -2307,10 +2310,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2307,10 +2310,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_mm_kwargs,
device=self.device,
),
**model_kwargs, **model_kwargs,
) )
......
...@@ -32,9 +32,9 @@ from vllm.model_executor.models.interfaces import supports_transcription ...@@ -32,9 +32,9 @@ from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import ( from vllm.model_executor.models.interfaces_base import (
is_pooling_model, is_text_generation_model) is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange) PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available, from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
...@@ -394,7 +394,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -394,7 +394,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState( self.requests[req_id] = CachedRequestState(
req_id=req_id, req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs, mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions, mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
...@@ -842,13 +842,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -842,13 +842,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]() mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]() req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id]) mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append( req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id])) (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
...@@ -859,16 +859,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -859,16 +859,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching # in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the # multimodal inputs. The proper solution should be reordering the
# encoder outputs. # encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = [] encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list: for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) mm_kwargs,
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
device=self.device, device=self.device,
) pin_memory=self.pin_memory,
):
# Run the encoder. # Run the encoder.
# `curr_group_outputs` is either of the following: # `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size) # 1. A tensor of shape (num_items, feature_size, hidden_size)
...@@ -878,12 +874,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -878,12 +874,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# depending on the input multimodal items. # depending on the input multimodal items.
xm.mark_step() xm.mark_step()
curr_group_outputs = self.model.get_multimodal_embeddings( curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs) **mm_kwargs_group)
xm.mark_step() xm.mark_step()
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(
curr_group_outputs, curr_group_outputs,
expected_num_items=len(grouped_mm_inputs), expected_num_items=num_items,
) )
if isinstance(curr_group_outputs, torch.Tensor): if isinstance(curr_group_outputs, torch.Tensor):
...@@ -1823,14 +1819,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1823,14 +1819,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Result in the maximum GPU consumption of the model # Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * return next(grouped_mm_kwargs
max_items_per_batch) for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
return MultiModalKwargs.as_kwargs( [dummy_mm_item] * max_items_per_batch,
batched_dummy_mm_inputs,
device=self.device, device=self.device,
) pin_memory=self.pin_memory,
))
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment