Unverified Commit ac5bc615 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] MiniCPM-V/O supports V1 (#15487)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8063dfc6
...@@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ
* `openbmb/MiniCPM-o-2_6`, etc. * `openbmb/MiniCPM-o-2_6`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* * ✅︎
- * `MiniCPMV` - * `MiniCPMV`
* MiniCPM-V * MiniCPM-V
* T + I<sup>E+</sup> + V<sup>E+</sup> * T + I<sup>E+</sup> + V<sup>E+</sup>
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. * `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* * ✅︎
- * `MllamaForConditionalGeneration` - * `MllamaForConditionalGeneration`
* Llama 3.2 * Llama 3.2
* T + I<sup>+</sup> * T + I<sup>+</sup>
......
This diff is collapsed.
This diff is collapsed.
...@@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict): ...@@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
num_crops: Union[torch.Tensor, list[torch.Tensor]] num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`""" """Shape: `(batch_size * num_images)`"""
@dataclass @dataclass
...@@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self.img_patch_id = img_patch_id.flatten().unique().item() self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch) embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
...@@ -1510,11 +1511,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1510,11 +1511,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch = image_input["feat_is_patch"] feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"] num_crops = image_input["num_crops"]
if isinstance(images, list):
# Call the vision backbone on the whole batch at once # Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True) images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn( image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True)) image_masks, concat=True))
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_features_flat = self.vision_backbone( image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0), images=images_flat.unsqueeze(0),
...@@ -1522,19 +1523,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1522,19 +1523,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_masks_flat.unsqueeze(0)), image_masks_flat.unsqueeze(0)),
).squeeze(0) ).squeeze(0)
# Reconstruct the batch dimension
num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_features = image_features_flat.split(num_crops_per_image)
else:
image_features = self.vision_backbone(
images=images,
image_masks=image_masks,
)
# Only the features corresponding to patch tokens are relevant # Only the features corresponding to patch tokens are relevant
return [ return [
feats[f_is_patch] feats[f_is_patch] for feats, f_is_patch in zip(
for feats, f_is_patch in zip(image_features, feat_is_patch) image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
] ]
def get_multimodal_embeddings( def get_multimodal_embeddings(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment