Unverified Commit ac5bc615 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] MiniCPM-V/O supports V1 (#15487)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8063dfc6
......@@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ
* `openbmb/MiniCPM-o-2_6`, etc.
* ✅︎
* ✅︎
*
* ✅︎
- * `MiniCPMV`
* MiniCPM-V
* T + I<sup>E+</sup> + V<sup>E+</sup>
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
* ✅︎
* ✅︎
*
* ✅︎
- * `MllamaForConditionalGeneration`
* Llama 3.2
* T + I<sup>+</sup>
......
This diff is collapsed.
This diff is collapsed.
......@@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
@dataclass
......@@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs(
images=images,
......@@ -1510,11 +1511,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"]
if isinstance(images, list):
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True))
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
......@@ -1522,19 +1523,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_masks_flat.unsqueeze(0)),
).squeeze(0)
# Reconstruct the batch dimension
num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_features = image_features_flat.split(num_crops_per_image)
else:
image_features = self.vision_backbone(
images=images,
image_masks=image_masks,
)
# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch]
for feats, f_is_patch in zip(image_features, feat_is_patch)
feats[f_is_patch] for feats, f_is_patch in zip(
image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
]
def get_multimodal_embeddings(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment