Unverified commit 997c8811, authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Support multi-image for Molmo (#15438)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent e42389f9
...@@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generativ
* *
- * `MolmoForCausalLM` - * `MolmoForCausalLM`
* Molmo * Molmo
* T + I * T + I<sup>+</sup>
* `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
......
...@@ -431,7 +431,7 @@ VLM_TEST_SETTINGS = { ...@@ -431,7 +431,7 @@ VLM_TEST_SETTINGS = {
), ),
"molmo": VLMTestInfo( "molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"], models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE), test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=identity, prompt_formatter=identity,
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
......
...@@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import select_patch_features from .vision import scatter_patch_features, select_patch_features
# TODO: hard-coded for now. Consider making it configurable. # TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9] VIT_LAYERS = [-2, -9]
...@@ -71,13 +71,13 @@ POOLING_SIZE = 2 ...@@ -71,13 +71,13 @@ POOLING_SIZE = 2
class MolmoImageInputs(TypedDict): class MolmoImageInputs(TypedDict):
images: Union[torch.Tensor, List[torch.Tensor]] images: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`""" """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
"""Shape: `(batch_size, num_crops, num_patch)`""" """Shape: `(batch_size, num_crops, num_patch)`"""
feat_is_patch: Union[torch.Tensor, List[torch.Tensor]] feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
A boolean mask indicating which image features correspond A boolean mask indicating which image features correspond
to patch tokens. to patch tokens.
...@@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict): ...@@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size, num_crops, num_patch)` Shape: `(batch_size, num_crops, num_patch)`
""" """
embed_is_patch: Union[torch.Tensor, List[torch.Tensor]] embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
...@@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict): ...@@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size, num_embeds)` Shape: `(batch_size, num_embeds)`
""" """
num_crops: torch.Tensor num_crops: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`""" """Shape: `(batch_size, num_images)`"""
...@@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper: ...@@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper:
image_input_idx = outputs.pop("image_input_idx", None) image_input_idx = outputs.pop("image_input_idx", None)
if image_input_idx is not None: if image_input_idx is not None:
input_is_patch = input_ids == self.image_patch_id feat_is_patch = image_input_idx >= 0
image_input_idx_flat: torch.Tensor = image_input_idx.view(-1)
image_valid_flat = image_input_idx_flat >= 0
feat_is_patch_flat = image_valid_flat.clone()
feat_is_patch_flat[image_valid_flat] = (
input_is_patch[image_input_idx_flat[image_valid_flat]])
feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape)
input_is_embed = torch.isin( input_is_embed = torch.isin(
input_ids, input_ids,
...@@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper: ...@@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper:
embed_is_patch = embed_ids == self.image_patch_id embed_is_patch = embed_ids == self.image_patch_id
assert embed_is_patch.sum() == feat_is_patch.sum() assert embed_is_patch.sum() == feat_is_patch.sum()
# image_tokens = extra_joint + joint
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
assert len(embed_start) == len(embed_end) == len(images)
embed_is_patch = [
embed_is_patch[start:end + 1]
for start, end in zip(embed_start, embed_end)
]
tilings = [ tilings = [
self.select_tiling( self.select_tiling(
image_width=image.size[0], image_width=image.size[0],
...@@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper: ...@@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper:
outputs["num_crops"] = num_crops outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id outputs["img_patch_id"] = self.image_patch_id
return BatchFeature(outputs, tensor_type=return_tensors) return BatchFeature(outputs)
class MolmoProcessingInfo(BaseProcessingInfo): class MolmoProcessingInfo(BaseProcessingInfo):
...@@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return MolmoProcessorWrapper(processor) return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
# TODO: Investigate different `embed_is_patch` between cache/no-cache return {"image": None}
# in multi-image case
return {"image": 1}
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
self, self,
...@@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image", num_crops), "image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes( feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops), "image", num_crops),
embed_is_patch=MultiModalFieldConfig.shared("image", num_images), embed_is_patch=MultiModalFieldConfig.batched("image"),
num_crops=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images), img_patch_id=MultiModalFieldConfig.shared("image", num_images),
) )
...@@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _process_image_input( def _process_image_input(
self, self,
image_input: MolmoImageInputs, image_input: MolmoImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
if isinstance(image_input["images"], list): if isinstance(image_input["images"], list):
# Call the vision backbone on the whole batch at once # Call the vision backbone on the whole batch at once
images_flat = flatten_bn(image_input["images"], concat=True) images_flat = flatten_bn(image_input["images"], concat=True)
...@@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch) feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,) num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,) embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> list[torch.Tensor]: ) -> tuple[torch.Tensor, ...]:
""" """
Scatter the patch features into a contiguous tensor that corresponds Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor. to the embedding tokens defined by the multimodal processor.
...@@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feats_per_image = features.split(num_crops_per_image) feats_per_image = features.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image) f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
_, _, embed_dim = features.shape features = torch.cat([
(num_embeds, ) = embed_is_patch.shape feats[f_is_patch]
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
embeds_in_batch = list[torch.Tensor]() ])
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)
return embeds_in_batch return scatter_patch_features(features, embed_is_patch)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
......
...@@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs( ...@@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs(
def scatter_patch_features( def scatter_patch_features(
features: torch.Tensor, features: torch.Tensor,
embed_is_patch: torch.Tensor, embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
""" """
Scatter the patch features into a contiguous tensor that corresponds Scatter the patch features into a contiguous tensor that corresponds
...@@ -194,14 +194,19 @@ def scatter_patch_features( ...@@ -194,14 +194,19 @@ def scatter_patch_features(
The resulting embedding tensor is: The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ] [ nan p1 p2 nan p3 p4 nan nan ]
""" """
num_images, num_embeds = embed_is_patch.shape num_embeds_per_image = [
num_embeds_per_image = [num_embeds] * num_images e_is_patch.numel() for e_is_patch in embed_is_patch
]
if isinstance(embed_is_patch, torch.Tensor):
embed_is_patch_flat = embed_is_patch.view(-1)
else:
embed_is_patch_flat = torch.cat(embed_is_patch)
embeds_flat = features.new_full( embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]), (sum(num_embeds_per_image), features.shape[-1]),
fill_value=torch.nan, fill_value=torch.nan,
) )
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2) embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
return embeds_flat.split(num_embeds_per_image) return embeds_flat.split(num_embeds_per_image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment