Unverified commit 962d2c63, authored by Michael Goin, committed by GitHub
Browse files

[Model][Pixtral] Use memory_efficient_attention for PixtralHFVision (#9520)

parent 5b59fe0f
...@@ -13,8 +13,7 @@ from transformers import PixtralVisionConfig, PretrainedConfig ...@@ -13,8 +13,7 @@ from transformers import PixtralVisionConfig, PretrainedConfig
from transformers.models.pixtral.image_processing_pixtral import ( from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens) _num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import ( from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
generate_block_attention_mask, position_ids_in_meshgrid)
from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -813,48 +812,30 @@ class PixtralHFAttention(nn.Module): ...@@ -813,48 +812,30 @@ class PixtralHFAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: BlockDiagonalMask,
position_embeddings: torch.Tensor, position_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel""" batch, patches, _ = hidden_states.size()
batch_size, patches, _ = hidden_states.size() q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
query_states = self.q_proj(hidden_states) v = self.v_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
# Transpose q and k to apply HF's Rotary Position Embedding
q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
key_states,
cos,
sin,
unsqueeze_dim=0)
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) * self.scale
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32 # Transpose q and k back for attention
attn_weights = nn.functional.softmax(attn_weights, q = q.transpose(1, 2).contiguous()
dim=-1, k = k.transpose(1, 2).contiguous()
dtype=torch.float32).to( v = v.reshape(batch, patches, self.n_heads, self.head_dim)
query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous() out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
attn_output = attn_output.reshape(batch_size, patches, -1) out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.o_proj(attn_output) return self.o_proj(out)
class PixtralHFTransformerBlock(nn.Module): class PixtralHFTransformerBlock(nn.Module):
...@@ -869,7 +850,7 @@ class PixtralHFTransformerBlock(nn.Module): ...@@ -869,7 +850,7 @@ class PixtralHFTransformerBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: BlockDiagonalMask,
position_embeddings: torch.Tensor, position_embeddings: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(hidden_states), r = self.attention.forward(self.attention_norm(hidden_states),
...@@ -892,7 +873,7 @@ class PixtralHFTransformer(nn.Module): ...@@ -892,7 +873,7 @@ class PixtralHFTransformer(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: BlockDiagonalMask,
position_embeddings: torch.Tensor, position_embeddings: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
for layer in self.layers: for layer in self.layers:
...@@ -953,9 +934,8 @@ class PixtralHFVisionModel(nn.Module): ...@@ -953,9 +934,8 @@ class PixtralHFVisionModel(nn.Module):
position_embedding = self.patch_positional_embedding( position_embedding = self.patch_positional_embedding(
patch_embeds, position_ids) patch_embeds, position_ids)
attention_mask = generate_block_attention_mask( attention_mask = BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
patch_embeds)
out = self.transformer(patch_embeds, attention_mask, out = self.transformer(patch_embeds, attention_mask,
position_embedding) position_embedding)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment