Unverified commit 2a00d324, authored by Andreas Karatzas and committed by GitHub
Browse files

[CI][MM] Gate vision encoder attention mask to MiniCPM only, fixing Aria regression (#36206)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent 10f4db4d
...@@ -359,6 +359,7 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -359,6 +359,7 @@ class Idefics2VisionTransformer(nn.Module):
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool = True, require_post_norm: bool = True,
apply_encoder_attention_mask: bool = False,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -366,6 +367,7 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -366,6 +367,7 @@ class Idefics2VisionTransformer(nn.Module):
embed_dim = config.hidden_size embed_dim = config.hidden_size
self.config = config self.config = config
self.use_data_parallel = is_vit_use_data_parallel() self.use_data_parallel = is_vit_use_data_parallel()
self.apply_encoder_attention_mask = apply_encoder_attention_mask
self.embeddings = Idefics2VisionEmbeddings(config) self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder( self.encoder = Idefics2Encoder(
config, config,
...@@ -425,10 +427,16 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -425,10 +427,16 @@ class Idefics2VisionTransformer(nn.Module):
) )
# Align with HuggingFace NaViT SigLIP in MiniCPMV/O: # Align with HuggingFace NaViT SigLIP in MiniCPMV/O:
# - if apply_encoder_attention_mask is False, skip (not all models
# sharing this encoder apply masking in attention, e.g. Aria, Phi4)
# - if patch_attention_mask was None, skip attention masking # - if patch_attention_mask was None, skip attention masking
# - if any padding exists, create an additive 4D mask and pass it # - if any padding exists, create an additive 4D mask and pass it
# to attention; else skip mask for performance. # to attention; else skip mask for performance.
if flat_patch_mask is None or not torch.any(~flat_patch_mask): if (
not self.apply_encoder_attention_mask
or flat_patch_mask is None
or not torch.any(~flat_patch_mask)
):
attention_mask = None attention_mask = None
else: else:
# Additive mask: masked positions receive a large negative value. # Additive mask: masked positions receive a large negative value.
......
...@@ -1336,6 +1336,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1336,6 +1336,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
model = Idefics2VisionTransformer( model = Idefics2VisionTransformer(
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
apply_encoder_attention_mask=True,
prefix=prefix, prefix=prefix,
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
...@@ -1428,6 +1429,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1428,6 +1429,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
model = Idefics2VisionTransformer( model = Idefics2VisionTransformer(
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
apply_encoder_attention_mask=True,
prefix=prefix, prefix=prefix,
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
...@@ -1525,6 +1527,7 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1525,6 +1527,7 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
model = Idefics2VisionTransformer( model = Idefics2VisionTransformer(
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
apply_encoder_attention_mask=True,
prefix=prefix, prefix=prefix,
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
...@@ -1622,6 +1625,7 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1622,6 +1625,7 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
model = Idefics2VisionTransformer( model = Idefics2VisionTransformer(
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
apply_encoder_attention_mask=True,
prefix=prefix, prefix=prefix,
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment