Unverified commit 19b927e5, authored by Cyrus Leung and committed by GitHub
Browse files

[Core] Use individual MM items in P0/P1 cache and model runner (#22570)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 20d65aa7
......@@ -113,6 +113,9 @@ class MsgpackEncoder:
int(v) if v is not None else None
for v in (obj.start, obj.stop, obj.step))
if isinstance(obj, MultiModalKwargsItem):
return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj
if not mm.modalities:
......@@ -120,17 +123,12 @@ class MsgpackEncoder:
return dict(mm)
# ignore the main dict, it will be re-indexed.
# Encode a list of MultiModalKwargsItems as plain dicts
# + special handling for .field.
# Any tensors *not* indexed by modality will be ignored.
return [[{
"modality": elem.modality,
"key": elem.key,
"data": self._encode_nested_tensors(elem.data),
"field": self._encode_mm_field(elem.field),
} for elem in item.values()]
for itemlist in mm._items_by_modality.values()
for item in itemlist]
return [
self._encode_mm_item(item)
for itemlist in mm._items_by_modality.values()
for item in itemlist
]
if isinstance(obj, UtilityResult):
result = obj.result
......@@ -192,6 +190,23 @@ class MsgpackEncoder:
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
def _encode_mm_item(self,
                    item: MultiModalKwargsItem) -> list[dict[str, Any]]:
    """Encode one multimodal item as a list of plain dicts.

    Every value yielded by ``item.values()`` is passed through
    ``_encode_mm_field_elem``; the output keeps that iteration order.
    """
    encoded_elems: list[dict[str, Any]] = []
    for field_elem in item.values():
        encoded_elems.append(self._encode_mm_field_elem(field_elem))
    return encoded_elems
def _encode_mm_field_elem(self,
                          elem: MultiModalFieldElem) -> dict[str, Any]:
    """Encode a single field element as a plain dict.

    The result carries the keys ``modality``, ``key``, ``data`` and
    ``field`` (in that order). ``data`` is left as ``None`` when the
    element has no data; otherwise it goes through
    ``_encode_nested_tensors``. ``field`` is always encoded with
    ``_encode_mm_field``.
    """
    payload: dict[str, Any] = {
        "modality": elem.modality,
        "key": elem.key,
    }

    # A missing payload is passed through untouched rather than being
    # handed to the tensor encoder.
    if elem.data is None:
        payload["data"] = None
    else:
        payload["data"] = self._encode_nested_tensors(elem.data)

    payload["field"] = self._encode_mm_field(elem.field)
    return payload
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor):
return self._encode_tensor(nt)
......@@ -250,6 +265,8 @@ class MsgpackDecoder:
return self._decode_tensor(obj)
if t is slice:
return slice(*obj)
if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
......@@ -311,15 +328,18 @@ class MsgpackDecoder:
# Convert back to proper shape & type
return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
    """Decode each entry of ``obj`` with ``_decode_mm_item``, in order."""
    decoded_items: list[MultiModalKwargsItem] = []
    for encoded_item in obj:
        decoded_items.append(self._decode_mm_item(encoded_item))
    return decoded_items
def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem:
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
    """Rebuild a ``MultiModalKwargsItem`` from its encoded field elements.

    Each entry of ``obj`` is decoded with ``_decode_mm_field_elem`` and the
    resulting list is handed to ``MultiModalKwargsItem.from_elems``.
    """
    build_item = MultiModalKwargsItem.from_elems
    decoded_elems = []
    for encoded_elem in obj:
        decoded_elems.append(self._decode_mm_field_elem(encoded_elem))
    return build_item(decoded_elems)
def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem:
obj["data"] = self._decode_nested_tensors(obj["data"])
def _decode_mm_field_elem(self, obj: dict[str,
Any]) -> MultiModalFieldElem:
if obj["data"] is not None:
obj["data"] = self._decode_nested_tensors(obj["data"])
# Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = obj["field"]
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
......
......@@ -7,9 +7,11 @@ from typing import Optional, cast
import numpy as np
import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
......@@ -29,7 +31,7 @@ class CachedRequestState:
req_id: str
prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs]
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
......@@ -51,6 +53,13 @@ class CachedRequestState:
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
# Back-compat shim, kept temporarily for plugins that define their own
# model runner and still read `mm_inputs`.
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
            "removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargs]:
    """Return one ``MultiModalKwargs`` per entry of ``self.mm_kwargs``.

    Each entry is wrapped on its own via ``MultiModalKwargs.from_items``,
    so the returned list has the same length and order as ``mm_kwargs``.
    """
    wrapped: list[MultiModalKwargs] = []
    for kwargs_item in self.mm_kwargs:
        wrapped.append(MultiModalKwargs.from_items([kwargs_item]))
    return wrapped
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
......
......@@ -40,9 +40,9 @@ from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
......@@ -478,7 +478,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
......@@ -496,18 +496,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts = []
audio_feature_lengths = []
use_audio_in_video = False
for mm_input in self.requests[req_id].mm_inputs:
for item in self.requests[req_id].mm_kwargs:
mm_input = item.require_data()
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
image_grid_thw.append(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
video_grid_thw.append(
mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None:
second_per_grid_ts.extend(
second_per_grid_ts.append(
mm_input["second_per_grid_ts"])
if mm_input.get("audio_feature_lengths") is not None:
audio_feature_lengths.extend(
audio_feature_lengths.append(
mm_input["audio_feature_lengths"])
if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True
......@@ -624,14 +625,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> BatchedTensorInputs:
if self.is_multimodal_raw_input_supported: # noqa: SIM102
if scheduler_output:
multi_modal_kwargs_list = list[MultiModalKwargs]()
mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs:
req_mm_inputs = req.mm_inputs
if not isinstance(req_mm_inputs, list):
req_mm_inputs = list(req_mm_inputs)
multi_modal_kwargs_list.extend(req_mm_inputs)
req_mm_kwargs = req.mm_kwargs
if not isinstance(req_mm_kwargs, list):
req_mm_kwargs = list(req_mm_kwargs)
mm_kwargs.extend(req_mm_kwargs)
# Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
return MultiModalKwargs.batch(multi_modal_kwargs_list)
return mm_kwargs_combined
return {}
......@@ -1146,13 +1156,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
# Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]()
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
......@@ -1163,17 +1173,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(
grouped_mm_inputs, pin_memory=self.pin_memory)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
)
pin_memory=self.pin_memory,
):
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
......@@ -1182,11 +1187,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
expected_num_items=num_items,
)
for output in curr_group_outputs:
......@@ -1553,17 +1558,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
model_kwargs = self._init_model_kwargs(num_scheduled_tokens)
model_kwargs = {
**self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
model_kwargs = self._init_model_kwargs(num_input_tokens)
inputs_embeds = None
model_mm_kwargs = {}
model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
else:
......@@ -1596,10 +1602,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_mm_kwargs,
device=self.device,
),
**model_kwargs,
)
......@@ -2196,14 +2198,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
max_items_per_batch)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
device=self.device,
)
return next(mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
[dummy_mm_item] * max_items_per_batch,
device=self.device,
pin_memory=self.pin_memory,
))
@torch.inference_mode()
def _dummy_run(
......@@ -2269,15 +2270,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model_kwargs = self._init_model_kwargs(num_tokens)
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
model_kwargs = {
**self._init_model_kwargs(num_tokens),
**self._dummy_mm_kwargs(num_reqs),
}
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
model_mm_kwargs = {}
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
......@@ -2307,10 +2310,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_mm_kwargs,
device=self.device,
),
**model_kwargs,
)
......
......@@ -32,9 +32,9 @@ from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
......@@ -394,7 +394,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=None,
......@@ -842,13 +842,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
# Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]()
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
......@@ -859,16 +859,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
)
pin_memory=self.pin_memory,
):
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
......@@ -878,12 +874,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# depending on the input multimodal items.
xm.mark_step()
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
**mm_kwargs_group)
xm.mark_step()
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
expected_num_items=num_items,
)
if isinstance(curr_group_outputs, torch.Tensor):
......@@ -1823,14 +1819,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
max_items_per_batch)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
device=self.device,
)
return next(grouped_mm_kwargs
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
[dummy_mm_item] * max_items_per_batch,
device=self.device,
pin_memory=self.pin_memory,
))
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment