Unverified Commit ac5bc615 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] MiniCPM-V/O supports V1 (#15487)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8063dfc6
...@@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ
* `openbmb/MiniCPM-o-2_6`, etc. * `openbmb/MiniCPM-o-2_6`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* * ✅︎
- * `MiniCPMV` - * `MiniCPMV`
* MiniCPM-V * MiniCPM-V
* T + I<sup>E+</sup> + V<sup>E+</sup> * T + I<sup>E+</sup> + V<sup>E+</sup>
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. * `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* * ✅︎
- * `MllamaForConditionalGeneration` - * `MllamaForConditionalGeneration`
* Llama 3.2 * Llama 3.2
* T + I<sup>+</sup> * T + I<sup>+</sup>
......
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
TypedDict, Union) Union)
import torch import torch
from torch import nn from torch import nn
...@@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, ...@@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVMultiModalDataParser, MiniCPMVMultiModalDataParser,
...@@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, ...@@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
_minicpmv_field_config) _minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix) maybe_prefix)
from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
class MiniCPMOAudioFeatureInputs(TypedDict): class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
audio_features: torch.Tensor audio_features: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_audios * num_slices, num_channels, length)` Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
Slice here means chunk. Audio that is too long will be split into slices, Slice here means chunk. Audio that is too long will be split into slices,
...@@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict): ...@@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
Padding is used therefore `audio_features` is `torch.Tensor`. Padding is used therefore `audio_features` is `torch.Tensor`.
""" """
audio_feature_lens: torch.Tensor audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_audios * num_slices)` Shape: `(batch_size * num_audios, num_slices)`
This should be feature length of each audio slice, This should be feature length of each audio slice,
which equals to `audio_features.shape[-1]` which equals to `audio_features.shape[-1]`
""" """
audio_bounds: torch.Tensor embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_audios * num_slices, 2)` A boolean mask indicating which audio embeddings correspond
to patch tokens.
This should be in `(start, stop)` format. Shape: `(batch_size * num_audios, num_embeds)`
""" """
class MiniCPMOAudioEmbeddingInputs(TypedDict): class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
audio_embeds: torch.Tensor audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images * num_slices, hidden_size)` Shape: `(batch_size * num_audios, num_slices, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor. instead of a batched tensor.
Length of each slice may vary, so pass it as a list. Length of each slice may vary, so pass it as a list.
""" """
audio_bounds: torch.Tensor
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_audios * num_slices, 2)` A boolean mask indicating which audio embeddings correspond
to patch tokens.
This should be in `(start, stop)` format. Shape: `(batch_size * num_audios, num_embeds)`
""" """
...@@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, ...@@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features = hf_inputs.get("audio_features", torch.empty(0))
num_audios = len(audio_features)
return dict( return dict(
**_minicpmv_field_config(hf_inputs), **_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.batched("audio"), audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
) )
...@@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
audio_pattern = "(<audio>./</audio>)" audio_pattern = "(<audio>./</audio>)"
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None} return {**super().get_supported_mm_limits(), "audio": None}
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
self, self,
...@@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Mapping[str, int]: ) -> Mapping[str, int]:
return { return {
"image": self.get_max_image_tokens(), **super().get_mm_max_tokens_per_item(seq_len, mm_counts),
"audio": self.get_max_audio_tokens(), "audio":
"video": self.get_max_video_tokens(seq_len), self.get_max_audio_tokens(),
} }
def get_audio_placeholder(
self,
audio_lens: int,
chunk_input: bool = True,
chunk_length: int = 1,
) -> str:
hf_processor = self.get_hf_processor()
return hf_processor.get_audio_placeholder(
audio_lens,
chunk_input=chunk_input,
chunk_length=chunk_length,
)
def get_default_audio_pool_step(self) -> int: def get_default_audio_pool_step(self) -> int:
return 2 return 2
...@@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
max_videos = mm_config.get_limit_per_prompt("video") max_videos = mm_config.get_limit_per_prompt("video")
max_audios = mm_config.get_limit_per_prompt("audio") max_audios = mm_config.get_limit_per_prompt("audio")
# count <image_idx></image_idx> tokens max_image_tokens = self.get_max_image_tokens() * max_images
# which are not in get_max_image_tokens max_audio_tokens = self.get_max_audio_tokens() * max_audios
max_image_tokens = self.get_max_image_tokens(
) * max_images + 4 * max_images
max_audio_tokens = self.get_max_audio_tokens(
) * max_audios + 2 * max_audios
max_total_frames = self.get_max_video_frames(seq_len - max_total_frames = self.get_max_video_frames(seq_len -
max_image_tokens - max_image_tokens -
max_audio_tokens) max_audio_tokens)
...@@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder( ...@@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder(
processor_inputs = super().get_dummy_processor_inputs( processor_inputs = super().get_dummy_processor_inputs(
seq_len, mm_counts) seq_len, mm_counts)
mm_data = {
"image": audio_prompt_texts = self.info.audio_pattern * num_audios
processor_inputs.mm_data["image"], audio_mm_data = {
"video":
processor_inputs.mm_data["video"],
"audio": "audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios) self._get_dummy_audios(length=audio_len, num_audios=num_audios)
} }
audio_prompt_texts = self.info.audio_pattern * num_audios return ProcessorInputs(
prompt_text=processor_inputs.prompt_text + audio_prompt_texts,
return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \ mm_data={
audio_prompt_texts, **processor_inputs.mm_data,
mm_data=mm_data) **audio_mm_data,
},
)
class MiniCPMOMultiModalProcessor( class MiniCPMOMultiModalProcessor(
...@@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor( ...@@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor(
return MiniCPMOMultiModalDataParser( return MiniCPMOMultiModalDataParser(
target_sr=self.info.get_default_audio_sampling_rate()) target_sr=self.info.get_default_audio_sampling_rate())
def get_audio_prompt_texts(self, def get_audio_prompt_texts(
self,
audio_lens: int, audio_lens: int,
chunk_input: bool = True, chunk_input: bool = True,
chunk_length: int = 1) -> str: chunk_length: int = 1,
return self.info.get_hf_processor().get_audio_placeholder( ) -> str:
audio_lens, chunk_input, chunk_length) return self.info.get_audio_placeholder(
audio_lens,
def get_special_tokens(self) -> Dict[str, torch.Tensor]: chunk_input=chunk_input,
tokenizer = self.info.get_tokenizer() chunk_length=chunk_length,
special_tokens = super().get_special_tokens() )
if hasattr(tokenizer, "audio_start_id"):
special_tokens["audio_start_id"] = torch.tensor(
tokenizer.audio_start_id)
special_tokens["audio_end_id"] = torch.tensor(
tokenizer.audio_end_id)
return special_tokens
def process_audios( def process_audios(
self, self,
...@@ -274,13 +286,25 @@ class MiniCPMOMultiModalProcessor( ...@@ -274,13 +286,25 @@ class MiniCPMOMultiModalProcessor(
parsed_audios = (self._get_data_parser().parse_mm_data({ parsed_audios = (self._get_data_parser().parse_mm_data({
"audio": audios "audio": audios
}).get_items("audio", AudioProcessorItems)) }).get_items("audio",
(MiniCPMOAudioEmbeddingItems, AudioProcessorItems)))
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
audio_inputs = {}
audio_lens = [
self.info.get_audio_len_by_num_chunks(
sum(map(len,
parsed_audios.get(i)["audio_embeds"])))
for i in range(len(parsed_audios))
]
else:
audio_inputs = self._base_call_hf_processor( audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios), prompts=[self.info.audio_pattern] * len(parsed_audios),
mm_data={"audios": [[audio] for audio in parsed_audios]}, mm_data={"audios": [[audio] for audio in parsed_audios]},
mm_kwargs={ mm_kwargs={
**mm_kwargs, "chunk_input": True **mm_kwargs,
"chunk_input": True,
}, },
out_keys={"audio_features", "audio_feature_lens"}, out_keys={"audio_features", "audio_feature_lens"},
) )
...@@ -295,10 +319,31 @@ class MiniCPMOMultiModalProcessor( ...@@ -295,10 +319,31 @@ class MiniCPMOMultiModalProcessor(
] ]
audio_inputs["audio_features"] = unpadded_audio_features audio_inputs["audio_features"] = unpadded_audio_features
return audio_inputs audio_lens = [
parsed_audios.get_audio_length(i)
for i in range(len(parsed_audios))
]
def get_placeholder_match_pattern(self) -> str: audio_repl_features = [
return r"\(<(image|video|audio)>./</\1>\)" self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
]
tokenizer = self.info.get_tokenizer()
audio_repls_feature_tokens = [
tokenizer.encode(audio_repl, add_special_tokens=False)
for audio_repl in audio_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(audio_repl_tokens)
for audio_repl_tokens in audio_repls_feature_tokens
]
audio_inputs["audio_embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
return audio_inputs
def process_mm_inputs( def process_mm_inputs(
self, self,
...@@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor( ...@@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor(
if isinstance(audios, MiniCPMOAudioEmbeddingItems): if isinstance(audios, MiniCPMOAudioEmbeddingItems):
single_audio_embeds = audios.get(item_idx)["audio_embeds"] single_audio_embeds = audios.get(item_idx)["audio_embeds"]
audio_len = self.info.get_audio_len_by_num_chunks( audio_len = self.info.get_audio_len_by_num_chunks(
sum(chunk_embeds.shape[0] sum(map(len, single_audio_embeds)))
for chunk_embeds in single_audio_embeds))
else: else:
audio_len = audios.get_audio_length(item_idx) audio_len = audios.get_audio_length(item_idx)
...@@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
self.apm = self.init_audio_module(vllm_config=vllm_config, self.apm = self.init_audio_module(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "apm")) prefix=maybe_prefix(prefix, "apm"))
self.audio_token_id = None
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Do not use parameters temporarily # Do not use parameters temporarily
audio_config = self.config.audio_config audio_config = self.config.audio_config
...@@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6):
return input_lengths_after_cnn, input_lengths_after_pooling return input_lengths_after_cnn, input_lengths_after_pooling
# Copied from HF repo of MiniCPM-o-2_6, def get_audio_hidden_states(
# designed for batched inputs and outputs self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]:
def get_audio_hidden_states(self, data: MiniCPMOAudioInputs, chunk_length = self.config.audio_chunk_length
chunk_length: int) -> list[torch.Tensor]:
wavforms = data.get( # (bs, 80, frames) or [], multi audios need filled in advance
"audio_features", wavforms_raw = data["audio_features"]
[]) # (bs, 80, frames) or [], multi audios need filled in advance if isinstance(wavforms_raw, list):
audio_feature_lens_raw = [data.get("audio_feature_lens", B = len(wavforms_raw)
[])] # list, [[x1, x2], [y1], [z1]] C = wavforms_raw[0].shape[-2]
L = max(item.shape[-1] for item in wavforms_raw)
device = wavforms_raw[0].device
dtype = wavforms_raw[0].dtype
wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
for i, wavforms_item in enumerate(wavforms_raw):
L_item = wavforms_item.shape[-1]
wavforms[i, ..., :L_item] = wavforms_item
else:
wavforms = wavforms_raw
if len(wavforms) == 0: # list, [[x1, x2], [y1], [z1]]
return [] audio_feature_lens_raw = data["audio_feature_lens"]
if isinstance(audio_feature_lens_raw, torch.Tensor):
audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)
audio_feature_lens = torch.hstack(audio_feature_lens_raw) audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape batch_size, _, max_mel_seq_len = wavforms.shape
...@@ -625,100 +683,52 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -625,100 +683,52 @@ class MiniCPMO(MiniCPMV2_6):
num_audio_tokens = feature_lens_after_pooling num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = [] final_audio_embeds = list[torch.Tensor]()
idx = 0 idx = 0
for i in range(len(audio_feature_lens_raw)): for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = [] target_audio_embeds_lst = list[torch.Tensor]()
for _ in range(len(audio_feature_lens_raw[i])): for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append( target_audio_embeds_lst.append(
audio_embeds[idx, :num_audio_tokens[idx], :]) audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1 idx += 1
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
def get_embedding_with_audios(self, vlm_embedding: torch.Tensor, final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
audio_inputs: MiniCPMOAudioInputs,
chunk_length: int) -> torch.Tensor: return final_audio_embeds
device, dtype = vlm_embedding.device, vlm_embedding.dtype
if audio_inputs["type"] == "audio_embeds":
audio_embeddings = [
item.to(device=device, dtype=dtype)
for item in audio_inputs["audio_embeds"]
]
else:
audio_embeddings = self.get_audio_hidden_states(
audio_inputs, chunk_length)[0]
if audio_embeddings is None or len(audio_embeddings) == 0:
return vlm_embedding
audio_bounds = audio_inputs["audio_bounds"]
if self.config.chunk_input:
audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
dtype=dtype)
audio_start_pos = 0
for bound in audio_bounds:
audio_len = bound[1] - bound[0]
vlm_embedding[bound[0]:bound[1]] = audio_embs[
audio_start_pos:audio_start_pos + audio_len, :]
audio_start_pos += audio_len
else:
for embs, bound in zip(audio_embeddings, audio_bounds):
audio_indices = torch.arange(bound[0],
bound[1],
dtype=torch.long).to(device)
if embs.shape[0] != len(audio_indices):
raise ValueError(
"Shape mismatch: Trying to assign embeddings "
f"of shape {embs.shape} "
f"to input indices of length {len(audio_indices)}")
vlm_embedding[audio_indices] = embs.to(dtype)
return vlm_embedding
def _get_audio_bounds(self, input_ids: torch.Tensor,
audio_start_id: torch.Tensor,
audio_end_id: torch.Tensor) -> torch.Tensor:
audio_start_tokens, = torch.where(input_ids == audio_start_id[0])
audio_start_tokens += 1
audio_end_tokens, = torch.where(input_ids == audio_end_id[0])
valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens))
return torch.hstack([
audio_start_tokens[:valid_audio_nums].unsqueeze(-1),
audio_end_tokens[:valid_audio_nums].unsqueeze(-1)
])
def _parse_and_validate_audio_inputs( def _parse_and_validate_audio_input(
self, input_ids: torch.Tensor, self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
**kwargs: object) -> Optional[MiniCPMOAudioInputs]:
audio_features = kwargs.pop("audio_features", None) audio_features = kwargs.pop("audio_features", None)
audio_embeds = kwargs.pop("audio_embeds", None) audio_embeds = kwargs.pop("audio_embeds", None)
if audio_features is None and audio_embeds is None: if audio_features is None and audio_embeds is None:
return None return None
audio_start_id = kwargs.pop("audio_start_id") audio_token_id = kwargs.pop("audio_token_id")
if not isinstance(audio_start_id, torch.Tensor): if audio_token_id is not None:
raise ValueError("Incorrect type of audio_start_id. " assert isinstance(audio_token_id, torch.Tensor)
f"Got type: {type(audio_start_id)}") self.mm_token_ids.add(audio_token_id.flatten().unique().item())
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embed_is_patch. "
f"Got type: {type(audio_embed_is_patch)}")
audio_end_id = kwargs.pop("audio_end_id") audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
if not isinstance(audio_end_id, torch.Tensor):
raise ValueError("Incorrect type of audio_end_id. "
f"Got type: {type(audio_end_id)}")
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)): if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. " raise ValueError("Incorrect type of audio_embeds. "
f"Got type: {type(audio_embeds)}") f"Got type: {type(audio_embeds)}")
audio_embeds_flat = flatten_bn(audio_embeds)
return MiniCPMOAudioEmbeddingInputs( return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds", type="audio_embeds",
audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds), audio_embeds=audio_embeds_flat,
concat=True), embed_is_patch=audio_embed_is_patch,
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
) )
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)): if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_features. " raise ValueError("Incorrect type of audio_features. "
f"Got type: {type(audio_features)}") f"Got type: {type(audio_features)}")
...@@ -728,56 +738,49 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -728,56 +738,49 @@ class MiniCPMO(MiniCPMV2_6):
raise ValueError("Incorrect type of audio_feature_lens. " raise ValueError("Incorrect type of audio_feature_lens. "
f"Got type: {type(audio_feature_lens)}") f"Got type: {type(audio_feature_lens)}")
audio_features_flat = flatten_bn(audio_features)
audio_feature_lens_flat = flatten_bn(audio_feature_lens)
return MiniCPMOAudioFeatureInputs( return MiniCPMOAudioFeatureInputs(
type="audio_features", type="audio_features",
audio_features=flatten_bn(audio_features, concat=True), audio_features=audio_features_flat,
audio_feature_lens=flatten_bn( audio_feature_lens=audio_feature_lens_flat,
flatten_2d_lists(audio_feature_lens), concat=True), embed_is_patch=audio_embed_is_patch,
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
) )
raise AssertionError("This line should be unreachable.") def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)
def _parse_and_validate_inputs(self, input_ids: torch.Tensor, # Preserve the order of modalities if there are multiple of them
**kwargs: object): # from the order of kwargs.
image_inputs = self._parse_and_validate_image_inputs( for input_key in kwargs:
input_ids, **kwargs) if input_key in ("audio_features",
if not any("audio" in key for key in kwargs): "audio_embeds") and "audios" not in modalities:
return image_inputs, None modalities["audios"] = self._parse_and_validate_audio_input(
audio_inputs = self._parse_and_validate_audio_inputs( **kwargs)
input_ids, **kwargs)
return image_inputs, audio_inputs
def forward( return modalities
def _process_audio_input(
self, self,
input_ids: torch.Tensor, audio_input: MiniCPMOAudioInputs,
positions: torch.Tensor, ) -> Union[torch.Tensor, list[torch.Tensor]]:
intermediate_tensors: Optional[IntermediateTensors] = None, if audio_input["type"] == "audio_embeds":
**kwargs: Any, return audio_input["audio_embeds"]
) -> torch.Tensor:
if intermediate_tensors is not None: return self.get_audio_hidden_states(audio_input)
vlm_embeddings = None
else: def _process_multimodal_inputs(self, modalities: dict):
image_inputs, audio_inputs = \ multimodal_embeddings = super()._process_multimodal_inputs(modalities)
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings = self.get_embedding_with_vision( for modality in modalities:
input_ids, image_inputs) if modality == "audios":
audio_input = modalities["audios"]
if audio_inputs is not None: audio_features = self._process_audio_input(audio_input)
vlm_embeddings = self.get_embedding_with_audios( multimodal_embeddings += tuple(
vlm_embeddings, audio_inputs, scatter_patch_features(
self.config.audio_chunk_length) audio_features,
audio_input["embed_is_patch"],
# always pass the input via `inputs_embeds` ))
# to make sure the computation graph is consistent
# for `torch.compile` integration return multimodal_embeddings
input_ids = None
output = self.llm.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings,
)
return output
...@@ -23,17 +23,15 @@ ...@@ -23,17 +23,15 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math import math
import re
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
TypedDict, Union) Union)
import numpy as np import numpy as np
import torch import torch
import torch.types import torch.types
from PIL import Image
from torch import nn from torch import nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -50,9 +48,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys ...@@ -50,9 +48,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
MultiModalInputs, NestedTensors,
PlaceholderRange)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
ImageProcessorItems, ImageSize, ImageProcessorItems, ImageSize,
ModalityData, ModalityDataItems, ModalityData, ModalityDataItems,
...@@ -67,13 +63,11 @@ from vllm.sequence import IntermediateTensors ...@@ -67,13 +63,11 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists from vllm.utils import flatten_2d_lists
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsV0Only) SupportsMultiModal, SupportsPP)
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
CPU_DEVICE = torch.device("cpu") from .vision import scatter_patch_features, select_patch_features
RawImageType = Union[Image.Image, torch.Tensor]
class MiniCPMVImagePixelInputs(TypedDict): class MiniCPMVImagePixelInputs(TypedDict):
...@@ -86,37 +80,41 @@ class MiniCPMVImagePixelInputs(TypedDict): ...@@ -86,37 +80,41 @@ class MiniCPMVImagePixelInputs(TypedDict):
instead of a batched tensor. instead of a batched tensor.
""" """
image_bounds: torch.Tensor tgt_sizes: torch.Tensor
""" """
Shape: `(batch_size * num_images * num_slices, 2)` Shape: `(batch_size * num_images * num_slices, 2)`
This should be in `(start, stop)` format. This should be in `(height, width)` format.
""" """
tgt_sizes: torch.Tensor embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images * num_slices, 2)` A boolean mask indicating which image embeddings correspond
to patch tokens.
This should be in `(height, width)` format. Shape: `(batch_size * num_images, num_embeds)`
""" """
num_slices: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class MiniCPMVImageEmbeddingInputs(TypedDict): class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
image_embeds: torch.Tensor image_embeds: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images * num_slices, Shape: `(batch_size * num_images, num_slices, hidden_size)`
image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor. instead of a batched tensor.
""" """
image_bounds: torch.Tensor embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images * num_slices, 2)` A boolean mask indicating which image embeddings correspond
to patch tokens.
This should be in `(start, stop)` format. Shape: `(batch_size * num_images, num_embeds)`
""" """
...@@ -233,15 +231,25 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: ...@@ -233,15 +231,25 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
num_images = len(pixel_values)
video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
num_videos = len(video_pixel_values)
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"), image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.batched("image"), tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"), video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"), video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.batched("video"),
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
) )
...@@ -348,10 +356,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -348,10 +356,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return get_version_by_config(self.get_hf_config()) return get_version_by_config(self.get_hf_config())
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
mm_limits = {"image": None}
if self.get_model_version() == (2, 6): if self.get_model_version() == (2, 6):
return {"image": None, "video": None} mm_limits["video"] = None
else:
return {"image": None} return mm_limits
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
self, self,
...@@ -361,70 +370,79 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -361,70 +370,79 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
mm_max_tokens = {"image": self.get_max_image_tokens()} mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6): if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len) mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
return mm_max_tokens return mm_max_tokens
def get_max_video_frame_tokens(self) -> int: def get_slice_image_placeholder(
frame_size = self.get_video_frame_size_with_most_features() self,
return self.get_num_image_tokens(frame_size, image_size: ImageSize,
self.get_video_max_slice_num()) # For MiniCPM V/O 2.6
image_idx: int = 0,
max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> str:
image_processor = self.get_image_processor()
version = self.get_model_version()
def get_max_video_tokens(self, seq_len: int) -> int: if version == (2, 0) or version == (2, 5):
return self.get_max_video_frame_tokens( return image_processor.get_slice_image_placeholder(image_size)
) * self.get_num_frames_with_most_features(seq_len)
def get_slice_query_num(self) -> int: return image_processor.get_slice_image_placeholder(
hf_config = self.get_hf_config() image_size,
query_num = getattr(hf_config, "query_num", 64) image_idx=image_idx,
return query_num max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
)
def get_max_slice_num(self) -> int: def get_num_image_tokens(
hf_config = self.get_hf_config() self,
max_slice_num = getattr(hf_config, "max_slice_num", 9) image_size: ImageSize,
return max_slice_num max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> int:
tokenizer = self.get_tokenizer()
image_placeholders = self.get_slice_image_placeholder(
image_size,
max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
)
image_token_ids = tokenizer.encode(image_placeholders,
add_special_tokens=False)
def get_sliced_grid(self, image_size: ImageSize, return len(image_token_ids)
max_slice_num: int) -> Tuple[int, int]:
if self.get_model_version() == (2, 6):
slice_grid = self.get_image_processor().get_sliced_grid(
image_size, max_slice_num)
else:
slice_grid = self.get_image_processor().get_sliced_grid(image_size)
return slice_grid
def get_num_image_tokens(self, image_size: ImageSize,
max_slice_num: int) -> int:
slice_grid = self.get_sliced_grid(image_size, max_slice_num)
num_tokens = self.get_slice_query_num(
) + 2 # <image>(<unk> * query_num)</image>
if slice_grid is not None:
if self.get_model_version() == (2, 6):
num_additional_tokens = 0
else:
# <slice><image>(<unk> * query_num)</image></slice>
num_additional_tokens = 2
num_tokens += ((self.get_slice_query_num() + 2) \
* slice_grid[0] * slice_grid[1]) \
+ slice_grid[1] - 1 + num_additional_tokens
return num_tokens
def get_image_slice_nums(self, image_size: torch.Tensor,
max_slice_nums: int) -> int:
grid = self.get_sliced_grid(image_size, max_slice_nums)
return 1 if grid is None else grid[0] * grid[1] + 1
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features() image_size = self.get_image_size_with_most_features()
return self.get_num_image_tokens(image_size, self.get_max_slice_num()) return self.get_num_image_tokens(image_size)
def get_image_max_slice_num(self) -> int:
return getattr(self.get_hf_config(), "max_slice_num", 9)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 9:1) image_size = getattr(self.get_hf_config(), "image_size", 448)
return self.get_default_image_sizes(self.get_max_slice_num()) max_slice_num = self.get_image_max_slice_num()
return ImageSize(width=image_size, height=image_size * max_slice_num)
def get_max_video_frame_tokens(self) -> int:
frame_size = self.get_video_frame_size_with_most_features()
return self.get_num_image_tokens(
frame_size,
max_slice_nums=self.get_video_max_slice_num(),
use_image_id=False,
)
def get_max_video_tokens(self, seq_len: int) -> int:
return self.get_max_video_frame_tokens(
) * self.get_num_frames_with_most_features(seq_len)
def get_video_max_slice_num(self) -> int: def get_video_max_slice_num(self) -> int:
return 1 return 1
def get_video_frame_size_with_most_features(self) -> ImageSize: def get_video_frame_size_with_most_features(self) -> ImageSize:
return self.get_default_image_sizes(self.get_video_max_slice_num()) image_size = getattr(self.get_hf_config(), "image_size", 448)
max_slice_num = self.get_video_max_slice_num()
return ImageSize(width=image_size, height=image_size * max_slice_num)
def get_max_video_frames(self, max_tokens: int) -> int: def get_max_video_frames(self, max_tokens: int) -> int:
num_frame_tokens = self.get_max_video_frame_tokens() num_frame_tokens = self.get_max_video_frame_tokens()
...@@ -436,10 +454,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -436,10 +454,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
max_images = mm_config.get_limit_per_prompt("image") max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video") max_videos = mm_config.get_limit_per_prompt("video")
# count <image_idx></image_idx> tokens max_image_tokens = self.get_max_image_tokens() * max_images
# which are not in get_max_image_tokens
max_image_tokens = self.get_max_image_tokens(
) * max_images + 4 * max_images
max_total_frames = self.get_max_video_frames(seq_len - max_total_frames = self.get_max_video_frames(seq_len -
max_image_tokens) max_image_tokens)
...@@ -447,10 +462,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -447,10 +462,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return num_frames return num_frames
def get_default_image_sizes(self, num_slices: int) -> ImageSize:
image_size = getattr(self.get_hf_config(), "image_size", 448)
return ImageSize(width=image_size, height=image_size * num_slices)
_I = TypeVar("_I", _I = TypeVar("_I",
bound=MiniCPMVProcessingInfo, bound=MiniCPMVProcessingInfo,
...@@ -499,42 +510,30 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -499,42 +510,30 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMVMultiModalDataParser() return MiniCPMVMultiModalDataParser()
def get_slice_image_placeholder(self, image_size: ImageSize,
**kwargs) -> str:
image_processor = self.info.get_image_processor()
version = self.info.get_model_version()
if version == (2, 0) or version == (2, 5):
return image_processor.get_slice_image_placeholder(image_size)
return image_processor.get_slice_image_placeholder(
image_size, **kwargs)
def get_image_prompt_texts(self, def get_image_prompt_texts(self,
image_size: ImageSize, image_size: ImageSize,
image_idx: int = 0) -> str: image_idx: int = 0) -> str:
return self.get_slice_image_placeholder(image_size, return self.info.get_slice_image_placeholder(
image_idx=image_idx) image_size,
image_idx=image_idx,
)
def get_video_prompt_texts(self, image_size: ImageSize, def get_video_prompt_texts(self, image_size: ImageSize,
num_frames: int) -> str: num_frames: int) -> str:
return self.get_slice_image_placeholder( return self.info.get_slice_image_placeholder(
image_size=image_size, image_size=image_size,
image_idx=0, image_idx=0,
max_slice_nums=self.info.get_video_max_slice_num(), max_slice_nums=self.info.get_video_max_slice_num(),
use_image_id=False, use_image_id=False,
) * num_frames ) * num_frames
def get_special_tokens(self) -> Dict[str, torch.Tensor]: def get_embed_is_patch(
self,
input_ids: list[int],
) -> torch.Tensor:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
special_tokens = { return torch.tensor(input_ids) == unk_token_id
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
}
if hasattr(tokenizer, "slice_start_id"):
special_tokens["slice_start_id"] = tokenizer.slice_start_id
special_tokens["slice_end_id"] = tokenizer.slice_end_id
return {k: torch.tensor(v) for k, v in special_tokens.items()}
def process_images( def process_images(
self, self,
...@@ -546,15 +545,44 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -546,15 +545,44 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
parsed_images = (self._get_data_parser().parse_mm_data({ parsed_images = (self._get_data_parser().parse_mm_data({
"image": images "image": images
}).get_items("image", ImageProcessorItems)) }).get_items("image",
(MiniCPMVImageEmbeddingItems, ImageProcessorItems)))
return self._base_call_hf_processor( if isinstance(parsed_images, MiniCPMVImageEmbeddingItems):
image_inputs = {}
else:
image_inputs = self._base_call_hf_processor(
prompts=[self.info.image_pattern] * len(parsed_images), prompts=[self.info.image_pattern] * len(parsed_images),
mm_data={"images": [[image] for image in parsed_images]}, mm_data={"images": [[image] for image in parsed_images]},
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
image_repl_features = [
self.get_image_prompt_texts(size, idx)
for idx, size in enumerate(image_sizes)
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(image_repl_tokens)
for image_repl_tokens in image_repls_feature_tokens
]
image_inputs["embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
return image_inputs
def process_videos( def process_videos(
self, self,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
...@@ -565,25 +593,55 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -565,25 +593,55 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
parsed_videos = (self._get_data_parser().parse_mm_data({ parsed_videos = (self._get_data_parser().parse_mm_data({
"video": videos "video": videos
}).get_items("video", VideoProcessorItems)) }).get_items("video",
(MiniCPMVVideoEmbeddingItems, VideoProcessorItems)))
max_slice_num = self.info.get_video_max_slice_num()
if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems):
video_inputs = {}
else:
video_inputs = self._base_call_hf_processor( video_inputs = self._base_call_hf_processor(
prompts=[ prompts=[
self.info.image_pattern * len(video) for video in parsed_videos self.info.image_pattern * len(video)
for video in parsed_videos
], ],
mm_data={"images": list(parsed_videos)}, mm_data={"images": list(parsed_videos)},
mm_kwargs={ mm_kwargs={
**mm_kwargs, "max_slice_nums": max_slice_num **mm_kwargs,
"max_slice_nums":
self.info.get_video_max_slice_num(),
}, },
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
return {f"video_{k}": v for k, v in video_inputs.items()} frame_sizes = [
parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
]
num_frames = [
parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
]
video_repl_features = [
self.get_video_prompt_texts(size, nframes)
for size, nframes in zip(frame_sizes, num_frames)
]
def get_placeholder_match_pattern(self) -> str: tokenizer = self.info.get_tokenizer()
return r"\(<(image|video)>./</\1>\)" video_repls_feature_tokens = [
tokenizer.encode(video_repl, add_special_tokens=False)
for video_repl in video_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(video_repl_tokens)
for video_repl_tokens in video_repls_feature_tokens
]
video_inputs["embed_is_patch"] = embed_is_patch
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
return video_inputs
def process_mm_inputs( def process_mm_inputs(
self, self,
...@@ -602,7 +660,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -602,7 +660,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
*, *,
out_keys: set[str], out_keys: set[str],
) -> Mapping[str, NestedTensors]: ) -> dict[str, NestedTensors]:
# This processor supports zipping prompt and mm_data together # This processor supports zipping prompt and mm_data together
if self.info.get_model_version() == (2, 6): if self.info.get_model_version() == (2, 6):
inputs = super()._call_hf_processor( inputs = super()._call_hf_processor(
...@@ -635,14 +693,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -635,14 +693,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
# Do not support combination inputs of images and videos for now
# Try to handle interleaved multimodal data
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
input_ids = torch.tensor([tokenizer.encode(prompt)])
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs) mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
return BatchFeature({ return BatchFeature({
"input_ids": "input_ids": input_ids,
torch.tensor([tokenizer.encode(prompt)]),
**mm_inputs, **mm_inputs,
}) })
...@@ -701,39 +758,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -701,39 +758,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return _minicpmv_field_config(hf_inputs) return _minicpmv_field_config(hf_inputs)
def apply(
self,
prompt: Union[str, List[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
if isinstance(prompt, list):
prompt = self.info.get_tokenizer().decode(prompt)
matches = re.findall(self.get_placeholder_match_pattern(), prompt)
mm_orders = {
f"{modality}_orders":
torch.tensor(
[index for index, m in enumerate(matches) if m == modality])
for modality in self.info.get_supported_mm_limits()
}
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
# Exclude <image_id>x</image_id> from placeholders
if "image" in result["mm_placeholders"] and \
self.info.get_model_version() == (2, 6):
result["mm_placeholders"]["image"] = [
PlaceholderRange(offset=p["offset"] + 3 + idx // 10,
length=p["length"] - 3 - idx // 10)
for idx, p in enumerate(result["mm_placeholders"]["image"])
]
result["mm_kwargs"].update(**mm_orders)
result["mm_kwargs"].update(**self.get_special_tokens())
return result
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
SupportsV0Only):
""" """
The abstract class of MiniCPMV can only be inherited, but cannot be The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated. instantiated.
...@@ -767,6 +793,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -767,6 +793,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
prefix=maybe_prefix( prefix=maybe_prefix(
prefix, "resampler")) prefix, "resampler"))
self.mm_token_ids = set[int]()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors) self.llm.make_empty_intermediate_tensors)
...@@ -777,233 +804,191 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -777,233 +804,191 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
return get_sampler() return get_sampler()
def get_embedding_with_vision( def _parse_and_validate_vision_input(
self, self,
input_ids: torch.Tensor, modality: str,
image_inputs: Optional[MiniCPMVImageInputs], **kwargs: object,
) -> torch.Tensor: ) -> Optional[MiniCPMVImageInputs]:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if image_inputs is None: if pixel_values is None and image_embeds is None:
return vlm_embedding return None
if image_inputs["type"] == "image_embeds": image_token_id = kwargs.pop("image_token_id")
vision_hidden_states = image_inputs["image_embeds"].to( if image_token_id is not None:
device=vlm_embedding.device, assert isinstance(image_token_id, torch.Tensor)
dtype=vlm_embedding.dtype, self.mm_token_ids.add(image_token_id.flatten().unique().item())
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack([
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of embed_is_patch for {modality=}. "
f"Got type: {type(embed_is_patch)}")
def _get_image_bounds( embed_is_patch = flatten_bn(embed_is_patch)
self,
input_ids: torch.Tensor,
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
if slice_start_id is not None:
start_cond |= (input_ids == slice_start_id[0])
end_cond |= (input_ids == slice_end_id[0])
image_start_tokens, = torch.where(start_cond)
image_start_tokens += 1
image_end_tokens, = torch.where(end_cond)
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
if valid_image_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
return torch.hstack([
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
])
def _parse_and_validate_image_inputs(
self,
input_ids: torch.Tensor,
**kwargs: object,
) -> Optional[MiniCPMVImageInputs]:
image_keys = {"pixel_values", "tgt_sizes"}
pixel_data = {
"image": {
key: kwargs.pop(key, None)
for key in image_keys
},
"video": {
key: kwargs.pop("video_" + key, None)
for key in image_keys
}
}
embed_data = {
"image": kwargs.pop("image_embeds", None),
"video": kwargs.pop("video_embeds", None),
}
all_pixel_data = [ if image_embeds is not None:
v for vs in pixel_data.values() for v in vs.values() if not isinstance(image_embeds, (torch.Tensor, list)):
if v is not None raise ValueError(
] f"Incorrect type of image_embeds for {modality=}. "
all_embed_data = [v for v in embed_data.values() if v is not None] f"Got type: {type(image_embeds)}")
if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
return None
im_start_id = kwargs.pop("im_start_id") image_embeds_flat = flatten_bn(image_embeds)
if not isinstance(im_start_id, torch.Tensor):
raise ValueError("Incorrect type of im_start_id. "
f"Got type: {type(im_start_id)}")
im_end_id = kwargs.pop("im_end_id")
if not isinstance(im_end_id, torch.Tensor):
raise ValueError("Incorrect type of im_end_id. "
f"Got type: {type(im_end_id)}")
slice_start_id = kwargs.pop("slice_start_id", None)
if slice_start_id is not None and not isinstance(
slice_start_id, torch.Tensor):
raise ValueError("Incorrect type of slice_start_id. "
f"Got type: {type(slice_start_id)}")
slice_end_id = kwargs.pop("slice_end_id", None)
if slice_end_id is not None and not isinstance(slice_end_id,
torch.Tensor):
raise ValueError("Incorrect type of slice_end_id. "
f"Got type: {type(slice_end_id)}")
if len(all_embed_data) > 0:
if len(all_embed_data) > 1:
raise ValueError("Incorrect inputs for vision embeddings. "
"Image embeds and video embeds can not "
"exist simultaneously.")
vision_embeds, = all_embed_data
if not isinstance(vision_embeds, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of vision_embeds. "
f"Got type: {type(vision_embeds)}")
return MiniCPMVImageEmbeddingInputs( return MiniCPMVImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
image_embeds=flatten_bn(flatten_2d_lists(vision_embeds), image_embeds=image_embeds_flat,
concat=True), embed_is_patch=embed_is_patch,
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
) )
order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]() if not isinstance(pixel_values, (torch.Tensor, list)):
for modality in ("image", "video"):
modality_orders = kwargs.pop(f"{modality}_orders", None)
if modality_orders is not None:
if not isinstance(modality_orders, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {modality}_orders. "
f"Got type: {type(modality_orders)}")
order_data[modality] = modality_orders
batch_sizes = {
modality: len(modality_orders)
for modality, modality_orders in order_data.items()
}
unique_batch_sizes = set(batch_sizes.values())
assert len(unique_batch_sizes) == 1, (
f"Found inconsistent batch sizes: {batch_sizes}")
batch_size, = unique_batch_sizes
pixel_values_flat = list[torch.Tensor]()
tgt_sizes_flat = list[torch.Tensor]()
for b in range(batch_size):
mm_orders_b = [(idx_b.item(), modality)
for modality, modality_orders in order_data.items()
for idx_b in modality_orders[b]]
for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
modality_pixel_data = pixel_data[modality]
modality_pixel_values = modality_pixel_data["pixel_values"]
if not isinstance(modality_pixel_values, (torch.Tensor, list)):
raise ValueError( raise ValueError(
f"Incorrect type of pixel_values for {modality=}. " f"Incorrect type of pixel_values for {modality=}. "
f"Got type: {type(modality_pixel_values)}") f"Got type: {type(pixel_values)}")
modality_tgt_sizes = modality_pixel_data["tgt_sizes"] tgt_sizes = kwargs.pop("tgt_sizes")
if not isinstance(modality_tgt_sizes, (torch.Tensor, list)): if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError( raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. "
f"Incorrect type of tgt_sizes for {modality=}. " f"Got type: {type(tgt_sizes)}")
f"Got type: {type(modality_tgt_sizes)}")
pixel_values_flat += flatten_2d_lists(modality_pixel_values[b]) num_slices = [[len(p) for p in ps] for ps in pixel_values]
tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b]) num_slices_flat = flatten_bn(torch.tensor(num_slices))
pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
# NOTE: Input IDs does not contain image tokens during memory profiling,
# so we allow it to be empty
if len(pixel_values_flat) != len(tgt_sizes_flat): if len(pixel_values_flat) != len(tgt_sizes_flat):
raise ValueError("Inconsistent flattened lengths, found: " raise ValueError("Inconsistent flattened lengths, found: "
f"{len(pixel_values_flat)} vs. " f"{len(pixel_values_flat)} vs. "
f"{len(tgt_sizes_flat)}") f"{len(tgt_sizes_flat)}")
if len(pixel_values_flat) == 0:
return None
return MiniCPMVImagePixelInputs( return MiniCPMVImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=pixel_values_flat, pixel_values=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat), tgt_sizes=tgt_sizes_flat,
image_bounds=self._get_image_bounds(input_ids, im_start_id, embed_is_patch=embed_is_patch,
im_end_id, slice_start_id, num_slices=num_slices_flat,
slice_end_id),
) )
def _parse_and_validate_inputs(self, input_ids: torch.Tensor, def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
**kwargs: object): modalities = {}
return self._parse_and_validate_image_inputs(input_ids, **kwargs)
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_vision_input(
"images", **kwargs)
if input_key in ("video_pixel_values",
"video_embeds") and "videos" not in modalities:
def _image_key(video_key: str):
if video_key == "video_token_id":
return "image_token_id"
return video_key.removeprefix("video_")
modalities["videos"] = self._parse_and_validate_vision_input(
"videos", **{
_image_key(k): v
for k, v in kwargs.items()
})
return modalities
def _process_vision_input(
self,
image_input: MiniCPMVImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["image_embeds"]
image_features_flat = self.get_vision_hidden_states(image_input)
# Reconstruct the batch dimension
return image_features_flat.split(image_input["num_slices"].tolist())
def _process_multimodal_inputs(self, modalities: dict):
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
image_features = self._process_vision_input(image_input)
multimodal_embeddings += tuple(
scatter_patch_features(
image_features,
image_input["embed_is_patch"],
))
if modality == "videos":
video_input = modalities["videos"]
video_features = self._process_vision_input(video_input)
multimodal_embeddings += tuple(
scatter_patch_features(
video_features,
video_input["embed_is_patch"],
))
return multimodal_embeddings
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
return self._process_multimodal_inputs(modalities)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
assert len(self.mm_token_ids) > 0
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
list(self.mm_token_ids),
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
if intermediate_tensors is not None: if intermediate_tensors is not None:
vlm_embeddings = None inputs_embeds = None
else:
image_inputs = \ # NOTE: In v1, inputs_embeds is always generated at model runner from
self._parse_and_validate_inputs(input_ids, **kwargs) # `get_multimodal_embeddings` and `get_input_embeddings`, this
vlm_embeddings = self.get_embedding_with_vision( # condition is only for v0 compatibility.
input_ids, image_inputs) elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent inputs_embeds = self.get_input_embeddings(input_ids,
# for `torch.compile` integration vision_embeddings)
input_ids = None input_ids = None
output = self.llm.model( hidden_states = self.llm.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings, inputs_embeds=inputs_embeds,
) )
return output return hidden_states
def compute_logits( def compute_logits(
self, self,
...@@ -1105,9 +1090,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -1105,9 +1090,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return model return model
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_tokens(input_ids)
def init_resampler(self, def init_resampler(self,
embed_dim: int, embed_dim: int,
vision_dim: int, vision_dim: int,
......
...@@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict): ...@@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
num_crops: Union[torch.Tensor, list[torch.Tensor]] num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`""" """Shape: `(batch_size * num_images)`"""
@dataclass @dataclass
...@@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self.img_patch_id = img_patch_id.flatten().unique().item() self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch) embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
...@@ -1510,11 +1511,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1510,11 +1511,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch = image_input["feat_is_patch"] feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"] num_crops = image_input["num_crops"]
if isinstance(images, list):
# Call the vision backbone on the whole batch at once # Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True) images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn( image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True)) image_masks, concat=True))
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_features_flat = self.vision_backbone( image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0), images=images_flat.unsqueeze(0),
...@@ -1522,19 +1523,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1522,19 +1523,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_masks_flat.unsqueeze(0)), image_masks_flat.unsqueeze(0)),
).squeeze(0) ).squeeze(0)
# Reconstruct the batch dimension
num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_features = image_features_flat.split(num_crops_per_image)
else:
image_features = self.vision_backbone(
images=images,
image_masks=image_masks,
)
# Only the features corresponding to patch tokens are relevant # Only the features corresponding to patch tokens are relevant
return [ return [
feats[f_is_patch] feats[f_is_patch] for feats, f_is_patch in zip(
for feats, f_is_patch in zip(image_features, feat_is_patch) image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
] ]
def get_multimodal_embeddings( def get_multimodal_embeddings(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment