"docs/vscode:/vscode.git/clone" did not exist on "b0746fae3d576528598d43da129ec5655197560b"
Unverified commit 5c4f2dd6, authored by Shanshan Shen, committed by GitHub
Browse files

[MM] Pass `prefix` parameter to MMEncoderAttention (#33674)


Signed-off-by: shen-shanshan <467638484@qq.com>
parent f3d8a346
...@@ -127,7 +127,10 @@ class AIMv2Attention(nn.Module): ...@@ -127,7 +127,10 @@ class AIMv2Attention(nn.Module):
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=prefix,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
......
...@@ -123,7 +123,10 @@ class BlipAttention(nn.Module): ...@@ -123,7 +123,10 @@ class BlipAttention(nn.Module):
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=prefix,
) )
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
......
...@@ -296,6 +296,7 @@ class Glm4vVisionAttention(nn.Module): ...@@ -296,6 +296,7 @@ class Glm4vVisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
prefix=prefix,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
......
...@@ -136,7 +136,10 @@ class EVA2CLIPAttention(nn.Module): ...@@ -136,7 +136,10 @@ class EVA2CLIPAttention(nn.Module):
) )
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
self.num_heads_per_rank, self.head_dim, self.scale self.num_heads_per_rank,
self.head_dim,
self.scale,
prefix=prefix,
) )
self.output_dropout = torch.nn.Dropout(config.dropout_prob) self.output_dropout = torch.nn.Dropout(config.dropout_prob)
......
...@@ -163,7 +163,10 @@ class Idefics2VisionAttention(nn.Module): ...@@ -163,7 +163,10 @@ class Idefics2VisionAttention(nn.Module):
) )
# Use unified MMEncoderAttention with Flash Attention support # Use unified MMEncoderAttention with Flash Attention support
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=prefix,
) )
def forward( def forward(
......
...@@ -212,7 +212,10 @@ class InternParallelAttention(nn.Module): ...@@ -212,7 +212,10 @@ class InternParallelAttention(nn.Module):
) )
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=prefix,
) )
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
......
...@@ -170,6 +170,7 @@ class InternSdpaAttention(nn.Module): ...@@ -170,6 +170,7 @@ class InternSdpaAttention(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
*, *,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -215,7 +216,12 @@ class InternSdpaAttention(nn.Module): ...@@ -215,7 +216,12 @@ class InternSdpaAttention(nn.Module):
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim) self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
# Use unified MMEncoderAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale) self.attn = MMEncoderAttention(
self.num_heads,
self.head_dim,
self.scale,
prefix=prefix,
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x shape: (B, N, C)""" """x shape: (B, N, C)"""
...@@ -313,7 +319,11 @@ class InternS1VisionLayer(nn.Module): ...@@ -313,7 +319,11 @@ class InternS1VisionLayer(nn.Module):
num_dummy_heads: int, num_dummy_heads: int,
prefix: str = "", prefix: str = "",
): ):
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) return InternSdpaAttention(
config,
num_dummy_heads=num_dummy_heads,
prefix=prefix,
)
def forward( def forward(
self, self,
......
...@@ -254,7 +254,10 @@ class Llama4VisionAttention(nn.Module): ...@@ -254,7 +254,10 @@ class Llama4VisionAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
self.num_local_heads, self.head_dim, self.scaling self.num_local_heads,
self.head_dim,
self.scaling,
prefix=prefix,
) )
if use_data_parallel: if use_data_parallel:
......
...@@ -231,7 +231,11 @@ class MultiHeadDotProductAttention(nn.Module): ...@@ -231,7 +231,11 @@ class MultiHeadDotProductAttention(nn.Module):
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
prefix=prefix,
) )
def forward( def forward(
......
...@@ -611,6 +611,7 @@ class ImagePoolingAttention(nn.Module): ...@@ -611,6 +611,7 @@ class ImagePoolingAttention(nn.Module):
self.head_dim, self.head_dim,
self.scale, self.scale,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
prefix=prefix,
) )
def forward_sdpa( def forward_sdpa(
......
...@@ -345,6 +345,7 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -345,6 +345,7 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
prefix=prefix,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
......
...@@ -319,6 +319,7 @@ class Qwen2VisionAttention(nn.Module): ...@@ -319,6 +319,7 @@ class Qwen2VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
prefix=prefix,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
......
...@@ -194,6 +194,7 @@ class Qwen3OmniMoeAudioAttention(nn.Module): ...@@ -194,6 +194,7 @@ class Qwen3OmniMoeAudioAttention(nn.Module):
num_heads=self.num_local_heads, num_heads=self.num_local_heads,
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scaling, scale=self.scaling,
prefix=prefix,
) )
def forward( def forward(
......
...@@ -759,7 +759,12 @@ class Step3VisionAttention(nn.Module): ...@@ -759,7 +759,12 @@ class Step3VisionAttention(nn.Module):
) )
# Use unified MMEncoderAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale) self.attn = MMEncoderAttention(
self.num_heads,
self.head_dim,
self.scale,
prefix=prefix,
)
def forward( def forward(
self, self,
......
...@@ -220,7 +220,12 @@ class PerceptionEncoderVisionAttention(nn.Module): ...@@ -220,7 +220,12 @@ class PerceptionEncoderVisionAttention(nn.Module):
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale) self.attn = MMEncoderAttention(
self.num_heads,
self.head_dim,
self.scale,
prefix=prefix,
)
self.rope = PerceptionEncoderRope2D( self.rope = PerceptionEncoderRope2D(
dim=self.head_dim, dim=self.head_dim,
max_grid_height=max_grid_height, max_grid_height=max_grid_height,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment