Unverified commit 6c66f28f, authored by Wenchen Lo and committed by GitHub
Browse files

Remove xformers requirement for Mistral-format Pixtral and Mistral3 (#21154)


Signed-off-by: Wenchen Lo <charles761013@gmail.com>
parent de509ae8
...@@ -671,7 +671,19 @@ class Attention(nn.Module): ...@@ -671,7 +671,19 @@ class Attention(nn.Module):
v = v.reshape(batch, patches, self.n_heads, self.head_dim) v = v.reshape(batch, patches, self.n_heads, self.head_dim)
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
if USE_XFORMERS_OPS:
out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
else:
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = nn.functional.scaled_dot_product_attention(q,
k,
v,
attn_mask=mask)
out = out.transpose(1, 2)
out = out.reshape(batch, patches, self.n_heads * self.head_dim) out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.wo(out) return self.wo(out)
...@@ -814,8 +826,11 @@ class VisionTransformer(nn.Module): ...@@ -814,8 +826,11 @@ class VisionTransformer(nn.Module):
mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
else: else:
raise ImportError("Xformers is required for Pixtral inference " from transformers.models.pixtral.modeling_pixtral import (
"with the Mistral format") generate_block_attention_mask)
mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
# squeeze dim 0 and split into separate tensors for each image # squeeze dim 0 and split into separate tensors for each image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment