"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "a8683102cc0ab9c1a0c3ae1ba2b7954f78eba1b3"
Unverified Commit a9e879b3 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Clean up MiniCPM-V/O code (#15337)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 3e2f37a6
...@@ -361,6 +361,7 @@ def run_llava_next_video(questions: list[str], ...@@ -361,6 +361,7 @@ def run_llava_next_video(questions: list[str],
engine_args = EngineArgs( engine_args = EngineArgs(
model="llava-hf/LLaVA-NeXT-Video-7B-hf", model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192, max_model_len=8192,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
......
...@@ -163,24 +163,24 @@ VLM_TEST_SETTINGS = { ...@@ -163,24 +163,24 @@ VLM_TEST_SETTINGS = {
marks=[pytest.mark.core_model, pytest.mark.cpu_model], marks=[pytest.mark.core_model, pytest.mark.cpu_model],
), ),
#### Extended model tests #### Extended model tests
# "aria": VLMTestInfo( "aria": VLMTestInfo(
# models=["rhymes-ai/Aria"], models=["rhymes-ai/Aria"],
# test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
# prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
# img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n", img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
# max_model_len=4096, max_model_len=4096,
# max_num_seqs=2, max_num_seqs=2,
# auto_cls=AutoModelForImageTextToText, auto_cls=AutoModelForImageTextToText,
# single_image_prompts=IMAGE_ASSETS.prompts({ single_image_prompts=IMAGE_ASSETS.prompts({
# "stop_sign": "<vlm_image>Please describe the image shortly.", "stop_sign": "<vlm_image>Please describe the image shortly.",
# "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501 "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
# }), }),
# multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501 multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
# stop_str=["<|im_end|>"], stop_str=["<|im_end|>"],
# image_size_factors=[(0.10, 0.15)], image_size_factors=[(0.10, 0.15)],
# max_tokens=64, max_tokens=64,
# marks=[large_gpu_mark(min_gb=64)], marks=[large_gpu_mark(min_gb=64)],
# ), ),
"blip2": VLMTestInfo( "blip2": VLMTestInfo(
models=["Salesforce/blip2-opt-2.7b"], models=["Salesforce/blip2-opt-2.7b"],
test_type=VLMTestType.IMAGE, test_type=VLMTestType.IMAGE,
...@@ -352,6 +352,7 @@ VLM_TEST_SETTINGS = { ...@@ -352,6 +352,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda vid_prompt: f"USER: {vid_prompt} ASSISTANT:", prompt_formatter=lambda vid_prompt: f"USER: {vid_prompt} ASSISTANT:",
num_video_frames=16, num_video_frames=16,
max_model_len=4096, max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq, auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
), ),
...@@ -384,7 +385,18 @@ VLM_TEST_SETTINGS = { ...@@ -384,7 +385,18 @@ VLM_TEST_SETTINGS = {
), ),
"minicpmo_26": VLMTestInfo( "minicpmo_26": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"], models=["openbmb/MiniCPM-o-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
),
"minicpmo_26_multi_image": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"],
test_type=(VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n", img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096, max_model_len=4096,
...@@ -392,10 +404,22 @@ VLM_TEST_SETTINGS = { ...@@ -392,10 +404,22 @@ VLM_TEST_SETTINGS = {
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
), ),
"minicpmv_26": VLMTestInfo( "minicpmv_26": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"], models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
),
"minicpmv_26_multi_image": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n", img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096, max_model_len=4096,
...@@ -403,6 +427,7 @@ VLM_TEST_SETTINGS = { ...@@ -403,6 +427,7 @@ VLM_TEST_SETTINGS = {
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
), ),
"molmo": VLMTestInfo( "molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"], models=["allenai/Molmo-7B-D-0924"],
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy
from functools import partial from functools import partial
from typing import Optional, Union from typing import Optional, Union
...@@ -29,7 +28,7 @@ def _test_processing_correctness( ...@@ -29,7 +28,7 @@ def _test_processing_correctness(
hit_rate: float, hit_rate: float,
num_batches: int, num_batches: int,
simplify_rate: float, simplify_rate: float,
ignore_mm_keys: Optional[list[str]] = None, ignore_mm_keys: Optional[set[str]] = None,
): ):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip") model_info.check_available_online(on_fail="skip")
...@@ -145,7 +144,7 @@ def _test_processing_correctness_hf( ...@@ -145,7 +144,7 @@ def _test_processing_correctness_hf(
baseline_processor: BaseMultiModalProcessor, baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor, cached_processor: BaseMultiModalProcessor,
batch_idx: int, batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None, ignore_mm_keys: Optional[set[str]] = None,
): ):
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"): if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token # For some multimodal models, tokenizer will always add bos_token
...@@ -167,11 +166,12 @@ def _test_processing_correctness_hf( ...@@ -167,11 +166,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
assert _inputs_equal( _assert_inputs_equal(
baseline_result, baseline_result,
cached_result, cached_result,
ignore_mm_keys, ignore_mm_keys=ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
baseline_tokenized_result = baseline_processor.apply( baseline_tokenized_result = baseline_processor.apply(
token_prompt, token_prompt,
...@@ -179,11 +179,12 @@ def _test_processing_correctness_hf( ...@@ -179,11 +179,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
assert _inputs_equal( _assert_inputs_equal(
baseline_result, baseline_result,
baseline_tokenized_result, baseline_tokenized_result,
ignore_mm_keys, ignore_mm_keys=ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
cached_tokenized_result = cached_processor.apply( cached_tokenized_result = cached_processor.apply(
token_prompt, token_prompt,
...@@ -191,11 +192,12 @@ def _test_processing_correctness_hf( ...@@ -191,11 +192,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
assert _inputs_equal( _assert_inputs_equal(
cached_result, cached_result,
cached_tokenized_result, cached_tokenized_result,
ignore_mm_keys, ignore_mm_keys=ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
def _test_processing_correctness_mistral( def _test_processing_correctness_mistral(
...@@ -206,7 +208,7 @@ def _test_processing_correctness_mistral( ...@@ -206,7 +208,7 @@ def _test_processing_correctness_mistral(
baseline_processor: BaseMultiModalProcessor, baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor, cached_processor: BaseMultiModalProcessor,
batch_idx: int, batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None, ignore_mm_keys: Optional[set[str]] = None,
): ):
images = mm_data.get("image", []) images = mm_data.get("image", [])
if not isinstance(images, list): if not isinstance(images, list):
...@@ -233,11 +235,12 @@ def _test_processing_correctness_mistral( ...@@ -233,11 +235,12 @@ def _test_processing_correctness_mistral(
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
assert _inputs_equal( _assert_inputs_equal(
baseline_tokenized_result, baseline_tokenized_result,
cached_tokenized_result, cached_tokenized_result,
ignore_mm_keys, ignore_mm_keys=ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
# yapf: disable # yapf: disable
...@@ -261,6 +264,7 @@ def _test_processing_correctness_mistral( ...@@ -261,6 +264,7 @@ def _test_processing_correctness_mistral(
"TIGER-Lab/Mantis-8B-siglip-llama3", "TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409", "mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
"openbmb/MiniCPM-Llama3-V-2_5",
"openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6", "openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924", "allenai/Molmo-7B-D-0924",
...@@ -290,7 +294,7 @@ def test_processing_correctness( ...@@ -290,7 +294,7 @@ def test_processing_correctness(
# In Ultravox, the audio_features can be different depending on padding # In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since # The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference. # attention_mask lets us ignore the difference.
ignore_mm_keys = ['audio_features'] ignore_mm_keys = {"audio_features"}
_test_processing_correctness( _test_processing_correctness(
model_id, model_id,
...@@ -328,38 +332,26 @@ def test_processing_correctness_phi3v( ...@@ -328,38 +332,26 @@ def test_processing_correctness_phi3v(
) )
def _inputs_equal( def _assert_inputs_equal(
a: MultiModalInputs, a: MultiModalInputs,
b: MultiModalInputs, b: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None, *,
ignore_mm_keys: Optional[set[str]] = None,
msg: str = "",
): ):
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys( if ignore_mm_keys is None:
b, ignore_mm_keys) ignore_mm_keys = set()
if msg is None:
def _drop_mm_kwargs_keys( assert "mm_kwargs" in a and "mm_kwargs" in b
result: MultiModalInputs, else:
ignore_mm_keys: Optional[list[str]] = None, assert "mm_kwargs" in a and "mm_kwargs" in b, msg
) -> MultiModalInputs:
"""Drop specified keys from result['mm_kwargs']. for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None)
This is mainly to avoid doing exact match of audio_features in ultravox. b["mm_kwargs"].pop(key, None)
Args: if msg is None:
result: Result to drop keys from assert a == b
ignore_mm_keys: List of keys to ignore, e.g. ['audio_features'] else:
""" assert a == b, msg
if not ignore_mm_keys:
return result
if 'mm_kwargs' in result:
result = copy.deepcopy(result)
mm_kwargs = result['mm_kwargs']
for key in ignore_mm_keys:
mm_kwargs.pop(key, None)
for items in mm_kwargs._items_by_modality.values():
for item in items:
for key in ignore_mm_keys:
item.pop(key, None)
return result
...@@ -295,8 +295,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -295,8 +295,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
# HF processor pops the `num_crops` kwarg, which is needed by vLLM # HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None: if (images := mm_data.get("images")) is not None:
assert isinstance(images, list)
parsed_images = (self._get_data_parser().parse_mm_data({ parsed_images = (self._get_data_parser().parse_mm_data({
"image": "image":
images images
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import torch import torch
...@@ -43,24 +43,26 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, ...@@ -43,24 +43,26 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVMultiModalDataParser, MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
_minicpmv_field_config) _minicpmv_field_config)
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix)
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
class MiniCPMOAudioFeatureInputs(TypedDict): class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
data: torch.Tensor audio_features: torch.Tensor
""" """
Shape: `(batch_size * num_audios * num_slices, num_channels, length)` Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
Slice here means chunk. Audio that is too long will be split into slices, Slice here means chunk. Audio that is too long will be split into slices,
which is the same as image. which is the same as image.
Padding is used therefore `data` is `torch.Tensor`. Padding is used therefore `audio_features` is `torch.Tensor`.
""" """
audio_feature_lens: torch.Tensor audio_feature_lens: torch.Tensor
...@@ -68,7 +70,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict): ...@@ -68,7 +70,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
Shape: `(batch_size * num_audios * num_slices)` Shape: `(batch_size * num_audios * num_slices)`
This should be feature length of each audio slice, This should be feature length of each audio slice,
which equals to `data.shape[-1]` which equals to `audio_features.shape[-1]`
""" """
audio_bounds: torch.Tensor audio_bounds: torch.Tensor
...@@ -81,7 +83,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict): ...@@ -81,7 +83,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
class MiniCPMOAudioEmbeddingInputs(TypedDict): class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
data: List[torch.Tensor] audio_embeds: torch.Tensor
""" """
Shape: `(batch_size * num_images * num_slices, hidden_size)` Shape: `(batch_size * num_images * num_slices, hidden_size)`
...@@ -102,18 +104,11 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, ...@@ -102,18 +104,11 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
return dict( return dict(
**_minicpmv_field_config(hf_inputs), **_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.flat_from_sizes( audio_features=MultiModalFieldConfig.batched("audio"),
"audio", audio_num_slices), audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes( audio_embeds=MultiModalFieldConfig.batched("audio"),
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
) )
...@@ -153,9 +148,6 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): ...@@ -153,9 +148,6 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
audio_pattern = "(<audio>./</audio>)" audio_pattern = "(<audio>./</audio>)"
def get_supported_mm_modalities(self) -> List[str]:
return ["image", "video", "audio"]
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None} return {"image": None, "video": None, "audio": None}
...@@ -277,95 +269,47 @@ class MiniCPMOMultiModalProcessor( ...@@ -277,95 +269,47 @@ class MiniCPMOMultiModalProcessor(
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]: ) -> Mapping[str, NestedTensors]:
mm_data = dict(mm_data) if (audios := mm_data.get("audios")) is None:
return {}
audios = mm_data.pop("audios", [])
audio_embeds = mm_data.pop("audio_embeds", []) parsed_audios = (self._get_data_parser().parse_mm_data({
if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0: "audio": audios
audio_outputs = { }).get_items("audio", AudioProcessorItems))
"audio_lens": [],
"audio_features": [], audio_inputs = self._base_call_hf_processor(
"audio_feature_lens": [], prompts=[self.info.audio_pattern] * len(parsed_audios),
"audio_num_segments": [] mm_data={"audios": [[audio] for audio in parsed_audios]},
} mm_kwargs={
for audio in audios: **mm_kwargs, "chunk_input": True
single_audio_outputs = super().call_base_hf_processor( },
prompt=self.info.audio_pattern, out_keys={"audio_features", "audio_feature_lens"},
mm_data={ )
"audios": audio,
"chunk_input": True # Avoid padding since we need the output for each audio to be
}, # independent of other audios for the cache to work correctly
mm_kwargs=mm_kwargs) unpadded_audio_features = [
audio_outputs["audio_lens"].append(len(audio)) feat[:, :feature_len] for feat, feature_len in zip(
audio_outputs["audio_features"].append( audio_inputs["audio_features"],
single_audio_outputs["audio_features"]) audio_inputs["audio_feature_lens"],
audio_outputs["audio_num_segments"].append( )
len(single_audio_outputs["audio_feature_lens"][0])) ]
audio_outputs["audio_feature_lens"] += \ audio_inputs["audio_features"] = unpadded_audio_features
single_audio_outputs["audio_feature_lens"]
audio_outputs["audio_features"] = [ return audio_inputs
audio_feature for single_audio_features in \
audio_outputs["audio_features"]
for audio_feature in single_audio_features
]
audio_outputs["audio_feature_lens"] = torch.cat(
audio_outputs["audio_feature_lens"])
elif len(audio_embeds):
audio_outputs = {
"audio_lens": [
self.info.get_audio_len_by_num_chunks(
sum(chunk_embeds.shape[0]
for chunk_embeds in single_audio_embeds))
for single_audio_embeds in audio_embeds
],
"audio_embeds": [
chunk_embeds for single_audio_embeds in audio_embeds
for chunk_embeds in single_audio_embeds
],
"audio_num_segments": [
len(single_audio_embeds)
for single_audio_embeds in audio_embeds
]
}
else:
audio_outputs = {}
return audio_outputs
def get_placeholder_match_pattern(self) -> str: def get_placeholder_match_pattern(self) -> str:
return r"\(<(image|video|audio)>./</\1>\)" return r"\(<(image|video|audio)>./</\1>\)"
def get_placeholder_split_pattern(self) -> str:
return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"
def process_mm_inputs( def process_mm_inputs(
self, self,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> Mapping[str, Mapping[str, NestedTensors]]: ) -> Mapping[str, NestedTensors]:
return { return {
"image": self.process_images(mm_data, mm_kwargs), **super().process_mm_inputs(mm_data, mm_kwargs),
"video": self.process_videos(mm_data, mm_kwargs), **self.process_audios(mm_data, mm_kwargs),
"audio": self.process_audios(mm_data, mm_kwargs),
} }
def get_modality_num_counter(self, modality: str) -> str:
if modality == "audio":
return "audio_lens"
return super().get_modality_num_counter(modality)
def get_num_slices_by_modality(self, inputs: Dict[str, object],
modality: str, index: int) -> int:
if modality == "audio":
return inputs["audio"]["audio_num_segments"][index]
return super().get_num_slices_by_modality(inputs, modality, index)
def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
modality: str, index: int) -> str:
if modality == "audio":
return self.get_audio_prompt_texts(
inputs["audio"]["audio_lens"][index])
return super().get_prompt_texts_by_modality(inputs, modality, index)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -622,86 +566,84 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -622,86 +566,84 @@ class MiniCPMO(MiniCPMV2_6):
# Copied from HF repo of MiniCPM-o-2_6, # Copied from HF repo of MiniCPM-o-2_6,
# designed for batched inputs and outputs # designed for batched inputs and outputs
def get_audio_hidden_states(self, data: MiniCPMOAudioInputs, def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
chunk_length: int) -> torch.Tensor: chunk_length: int) -> list[torch.Tensor]:
wavforms = data.get( wavforms = data.get(
"data", "audio_features",
[]) # (bs, 80, frames) or [], multi audios need filled in advance []) # (bs, 80, frames) or [], multi audios need filled in advance
audio_feature_lens_raw = [data.get("audio_feature_lens", audio_feature_lens_raw = [data.get("audio_feature_lens",
[])] # list, [[x1, x2], [y1], [z1]] [])] # list, [[x1, x2], [y1], [z1]]
# exist audio if len(wavforms) == 0:
if len(wavforms) > 0:
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (torch.arange(
0,
max_seq_len,
dtype=audio_feature_lens.dtype,
device=audio_feature_lens.device).unsqueeze(0).expand(
batch_size, max_seq_len))
lengths_expand = audio_feature_lens.unsqueeze(1).expand(
batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand # 1 for padded values
audio_attention_mask_ = padding_mask.view(
batch_size, 1, 1, max_seq_len).expand(batch_size, 1,
max_seq_len, max_seq_len)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.apm.conv1.weight.dtype,
device=self.apm.conv1.weight.device)
if chunk_length > 0:
chunk_num_frame = int(chunk_length * 50)
chunk_mask = self.subsequent_chunk_mask(
size=max_seq_len,
chunk_size=chunk_num_frame,
num_left_chunks=-1,
device=audio_attention_mask_.device,
)
audio_attention_mask_ = torch.logical_or(
audio_attention_mask_, torch.logical_not(chunk_mask))
audio_attention_mask[audio_attention_mask_] = float("-inf")
audio_states = self.apm(
wavforms, attention_mask=audio_attention_mask).hidden_states[
self.audio_encoder_layer]
audio_embeds = self.audio_projection_layer(audio_states)
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
_, feature_lens_after_pooling = \
self._get_feat_extract_output_lengths(audio_feature_lens)
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(
audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
else:
return [] return []
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (torch.arange(
0,
max_seq_len,
dtype=audio_feature_lens.dtype,
device=audio_feature_lens.device).unsqueeze(0).expand(
batch_size, max_seq_len))
lengths_expand = audio_feature_lens.unsqueeze(1).expand(
batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand # 1 for padded values
audio_attention_mask_ = padding_mask.view(
batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
max_seq_len)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.apm.conv1.weight.dtype,
device=self.apm.conv1.weight.device)
if chunk_length > 0:
chunk_num_frame = int(chunk_length * 50)
chunk_mask = self.subsequent_chunk_mask(
size=max_seq_len,
chunk_size=chunk_num_frame,
num_left_chunks=-1,
device=audio_attention_mask_.device,
)
audio_attention_mask_ = torch.logical_or(
audio_attention_mask_, torch.logical_not(chunk_mask))
audio_attention_mask[audio_attention_mask_] = float("-inf")
audio_states = self.apm(
wavforms, attention_mask=audio_attention_mask).hidden_states[
self.audio_encoder_layer]
audio_embeds = self.audio_projection_layer(audio_states)
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
_, feature_lens_after_pooling = \
self._get_feat_extract_output_lengths(audio_feature_lens)
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(
audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
def get_embedding_with_audios(self, vlm_embedding: torch.Tensor, def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
audio_inputs: Optional[MiniCPMOAudioInputs], audio_inputs: MiniCPMOAudioInputs,
chunk_length: int) -> torch.Tensor: chunk_length: int) -> torch.Tensor:
device, dtype = vlm_embedding.device, vlm_embedding.dtype device, dtype = vlm_embedding.device, vlm_embedding.dtype
if audio_inputs["type"] == "audio_embeds": if audio_inputs["type"] == "audio_embeds":
audio_embeddings = audio_inputs["data"]
audio_embeddings = [ audio_embeddings = [
audio_embeddings[i].to(device=device, dtype=dtype) item.to(device=device, dtype=dtype)
for i in range(len(audio_embeddings)) for item in audio_inputs["audio_embeds"]
] ]
else: else:
audio_embeddings = self.get_audio_hidden_states( audio_embeddings = self.get_audio_hidden_states(
...@@ -746,40 +688,56 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -746,40 +688,56 @@ class MiniCPMO(MiniCPMV2_6):
def _parse_and_validate_audio_inputs( def _parse_and_validate_audio_inputs(
self, input_ids: torch.Tensor, self, input_ids: torch.Tensor,
**kwargs: object) -> Tuple[MiniCPMOAudioInputs]: **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
audio_features = kwargs.pop("audio_features", []) audio_features = kwargs.pop("audio_features", None)
audio_feature_lens = kwargs.pop("audio_feature_lens", [])
audio_embeds = kwargs.pop("audio_embeds", None) audio_embeds = kwargs.pop("audio_embeds", None)
audio_start_id = kwargs.pop("audio_start_id", None)
audio_end_id = kwargs.pop("audio_end_id", None) if audio_features is None and audio_embeds is None:
return None
audio_start_id = kwargs.pop("audio_start_id")
if not isinstance(audio_start_id, torch.Tensor):
raise ValueError("Incorrect type of audio_start_id. "
f"Got type: {type(audio_start_id)}")
audio_end_id = kwargs.pop("audio_end_id")
if not isinstance(audio_end_id, torch.Tensor):
raise ValueError("Incorrect type of audio_end_id. "
f"Got type: {type(audio_end_id)}")
if audio_embeds is not None: if audio_embeds is not None:
audio_embeds = [ if not isinstance(audio_embeds, (torch.Tensor, list)):
audio_embeds[i][j] for i in range(len(audio_embeds)) raise ValueError("Incorrect type of audio_embeds. "
for j in range(len(audio_embeds[i])) f"Got type: {type(audio_embeds)}")
]
return MiniCPMOAudioEmbeddingInputs( return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds",
audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id), audio_end_id),
data=audio_embeds, )
type="audio_embeds")
if len(audio_features) > 0: if audio_features is not None:
audio_features_all = [ if not isinstance(audio_features, (torch.Tensor, list)):
i.permute(1, 0) for audio_feature in audio_features raise ValueError("Incorrect type of audio_features. "
for i in audio_feature f"Got type: {type(audio_features)}")
]
audio_features = torch.nn.utils.rnn.pad_sequence( audio_feature_lens = kwargs.pop("audio_feature_lens")
audio_features_all, batch_first=True, if not isinstance(audio_feature_lens, (torch.Tensor, list)):
padding_value=0.0).permute(0, 2, 1) raise ValueError("Incorrect type of audio_feature_lens. "
audio_feature_lens = torch.cat( f"Got type: {type(audio_feature_lens)}")
[item for item in audio_feature_lens])
return MiniCPMOAudioFeatureInputs( return MiniCPMOAudioFeatureInputs(
type="audio_features",
audio_features=flatten_bn(audio_features, concat=True),
audio_feature_lens=flatten_bn(
flatten_2d_lists(audio_feature_lens), concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id), audio_end_id),
data=audio_features, )
audio_feature_lens=audio_feature_lens,
type="audio_features") raise AssertionError("This line should be unreachable.")
return None
def _parse_and_validate_inputs(self, input_ids: torch.Tensor, def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
**kwargs: object): **kwargs: object):
...@@ -803,7 +761,7 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -803,7 +761,7 @@ class MiniCPMO(MiniCPMV2_6):
else: else:
image_inputs, audio_inputs = \ image_inputs, audio_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs) self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding_with_vision( vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs) input_ids, image_inputs)
if audio_inputs is not None: if audio_inputs is not None:
......
This diff is collapsed.
...@@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return cast(BatchedTensorInputs, json_mapped) return cast(BatchedTensorInputs, json_mapped)
def __delitem__(self, key: str) -> None:
super().__delitem__(key)
for items in self._items_by_modality.values():
for item in items:
item.pop(key, None)
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment