Commit 9180c729 authored by mashun1

Update blocks.py

parent c1676293
```diff
@@ -282,7 +282,7 @@ class MultiHeadCrossAttention(nn.Module):
         attn_bias = None
         if mask is not None:
             attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
-        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias, op=xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp)
         x = x.view(B, -1, C)
         x = self.proj(x)
```
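For context, here is a minimal, self-contained sketch (not part of the commit) of how the patched call is typically used. xformers' `BlockDiagonalMask` expects all sequences packed along batch dimension 1, and passing `op=fmha.MemoryEfficientAttentionFlashAttentionOp` pins the FlashAttention backend, which requires fp16/bf16 inputs on CUDA. All sizes and tensor names below are hypothetical.

```python
import torch
import xformers.ops
from xformers.ops import fmha

# Hypothetical sizes: B samples, N query tokens each, per-sample KV lengths.
B, N, num_heads, head_dim = 2, 16, 8, 64
kv_lens = [10, 12]

# Block-diagonal attention packs every sequence into batch dim 1:
# shape is [1, total_tokens, num_heads, head_dim].
q = torch.randn(1, B * N, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(1, sum(kv_lens), num_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

# Queries of sample i may only attend to the keys/values of sample i.
attn_bias = fmha.BlockDiagonalMask.from_seqlens([N] * B, kv_lens)

# Force the FlashAttention backend, as the commit does via `op=...`.
x = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=attn_bias,
    op=fmha.MemoryEfficientAttentionFlashAttentionOp,
)

# Unpack back to [B, N, num_heads * head_dim].
x = x.reshape(B, N, num_heads * head_dim)
```

Without the explicit `op=`, xformers dispatches to whichever backend it deems best for the inputs; pinning the op trades that flexibility for a predictable kernel, and fails loudly if the inputs (dtype, device, head size) are unsupported by FlashAttention.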