Unverified commit 8693e47e, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Fix `mm_hashes` forgetting to be passed (#15668)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent cec8c7d7
...@@ -528,6 +528,7 @@ class InputPreprocessor: ...@@ -528,6 +528,7 @@ class InputPreprocessor:
prompt_token_ids=decoder_inputs_to_override[ prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"], "prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"], mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"], mm_placeholders=inputs["mm_placeholders"],
) )
else: else:
...@@ -536,6 +537,7 @@ class InputPreprocessor: ...@@ -536,6 +537,7 @@ class InputPreprocessor:
prompt=inputs["prompt"], prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"], prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"], mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"], mm_placeholders=inputs["mm_placeholders"],
) )
elif inputs["type"] == "token": elif inputs["type"] == "token":
......
...@@ -868,6 +868,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -868,6 +868,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
mm_kwargs = result["mm_kwargs"] mm_kwargs = result["mm_kwargs"]
mm_hashes = result["mm_hashes"]
# We reimplement the functionality of MLlavaProcessor from # We reimplement the functionality of MLlavaProcessor from
# https://github.com/TIGER-AI-Lab/Mantis.git # https://github.com/TIGER-AI-Lab/Mantis.git
...@@ -916,6 +917,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -916,6 +917,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt=prompt, prompt=prompt,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges, mm_placeholders=mm_placeholder_ranges,
) )
......
...@@ -1378,7 +1378,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1378,7 +1378,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
# Because attn_metadata.encoder_seq_lens only counts the last # Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the # group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only. # block manager to allocate blocks for those images only.
# See input_processor_for_mllama() for more details. # See MllamaMultiModalProcessor for more details.
num_tiles_tensor = kwargs.pop("num_tiles") num_tiles_tensor = kwargs.pop("num_tiles")
num_tiles = [t.tolist() for t in num_tiles_tensor] num_tiles = [t.tolist() for t in num_tiles_tensor]
num_tokens_per_tile = calc_token_per_chunk(self.image_size) num_tokens_per_tile = calc_token_per_chunk(self.image_size)
......
...@@ -28,7 +28,7 @@ from vllm.model_executor.models.llama import LlamaModel ...@@ -28,7 +28,7 @@ from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalInputs, NestedTensors from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
...@@ -1319,9 +1319,9 @@ def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, ...@@ -1319,9 +1319,9 @@ def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int,
def input_mapper_for_phi4mm_audio(ctx: InputContext, def input_mapper_for_phi4mm_audio(ctx: InputContext,
data: object) -> MultiModalInputs: data: object) -> MultiModalKwargs:
""" """
This function is used to create the MultiModalInputs for the Phi4MM This function is used to create the MultiModalKwargs for the Phi4MM
(audio) model. (audio) model.
Specifically, for audio, we extract the audio features from the sound Specifically, for audio, we extract the audio features from the sound
file and create pairs of audio features and audio embed lengths (the file and create pairs of audio features and audio embed lengths (the
...@@ -1338,13 +1338,13 @@ def input_mapper_for_phi4mm_audio(ctx: InputContext, ...@@ -1338,13 +1338,13 @@ def input_mapper_for_phi4mm_audio(ctx: InputContext,
data (object): Audio data. data (object): Audio data.
Returns: Returns:
MultiModalInputs: Multi-modal inputs. MultiModalKwargs: Multi-modal inputs.
""" """
if not isinstance(data, list): if not isinstance(data, list):
data = [data] data = [data]
if len(data) == 0: if len(data) == 0:
return MultiModalInputs() return MultiModalKwargs()
audio_features = [] audio_features = []
for audio_input in data: for audio_input in data:
...@@ -1365,7 +1365,7 @@ def input_mapper_for_phi4mm_audio(ctx: InputContext, ...@@ -1365,7 +1365,7 @@ def input_mapper_for_phi4mm_audio(ctx: InputContext,
[single_audio_embed_size], [single_audio_embed_size],
) )
audio_features.append(single_audio_feature_audio_len_pair) audio_features.append(single_audio_feature_audio_len_pair)
return MultiModalInputs({"audio_features": audio_features}) return MultiModalKwargs({"audio_features": audio_features})
def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): def input_mapper_for_phi4mm_image(ctx: InputContext, data: object):
...@@ -1373,7 +1373,7 @@ def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): ...@@ -1373,7 +1373,7 @@ def input_mapper_for_phi4mm_image(ctx: InputContext, data: object):
data = [data] data = [data]
# data: list of PIL images # data: list of PIL images
if len(data) == 0: if len(data) == 0:
return MultiModalInputs() return MultiModalKwargs()
hf_config = ctx.get_hf_config() hf_config = ctx.get_hf_config()
vision_encoder_name = hf_config.img_processor vision_encoder_name = hf_config.img_processor
if vision_encoder_name is None: if vision_encoder_name is None:
...@@ -1385,7 +1385,7 @@ def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): ...@@ -1385,7 +1385,7 @@ def input_mapper_for_phi4mm_image(ctx: InputContext, data: object):
image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size, image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size,
vit_patch_size) vit_patch_size)
return MultiModalInputs({ return MultiModalKwargs({
"pixel_values": "pixel_values":
image_input_dict["pixel_values"], image_input_dict["pixel_values"],
"image_sizes": "image_sizes":
......
...@@ -105,6 +105,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -105,6 +105,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
prompt=prompt, prompt=prompt,
prompt_token_ids=[1], prompt_token_ids=[1],
mm_kwargs=MultiModalKwargs(mm_kwargs), mm_kwargs=MultiModalKwargs(mm_kwargs),
mm_hashes=None,
mm_placeholders={}, mm_placeholders={},
) )
......
...@@ -743,7 +743,7 @@ class MultiModalInputs(TypedDict): ...@@ -743,7 +743,7 @@ class MultiModalInputs(TypedDict):
mm_kwargs: MultiModalKwargs mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
mm_hashes: NotRequired[Optional["MultiModalHashDict"]] mm_hashes: Optional["MultiModalHashDict"]
"""The hashes of the multi-modal data.""" """The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict mm_placeholders: MultiModalPlaceholderDict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment