Commit 19d12610 authored by Tri Dao

Add back need_weights in FlashMHA

parent 6cc73425
@@ -98,7 +98,7 @@ class FlashMHA(nn.Module):
         self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
 
-    def forward(self, x, key_padding_mask=None):
+    def forward(self, x, key_padding_mask=None, need_weights=False):
         """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
         key_padding_mask: bool tensor of shape (batch, seqlen)
         """
...
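
For context, need_weights mirrors the flag on torch.nn.MultiheadAttention, so restoring it in FlashMHA.forward lets callers written against that interface keep passing the argument without a TypeError. Below is a minimal sketch of that calling convention using nn.MultiheadAttention as the reference module; only the forward signature above is taken from this diff, and FlashMHA's return convention is not shown in the hunk, so treat the tuple-return pattern here as an assumption.

    import torch
    import torch.nn as nn

    # Reference module whose forward interface FlashMHA partially mirrors.
    # need_weights=False tells the module it may skip returning the
    # (batch, seqlen, seqlen) attention matrix.
    mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

    x = torch.randn(2, 128, 512)                              # (batch, seqlen, hidden_dim)
    key_padding_mask = torch.zeros(2, 128, dtype=torch.bool)  # True marks padded positions

    out, attn_weights = mha(x, x, x,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)
    print(out.shape)     # torch.Size([2, 128, 512])
    print(attn_weights)  # None when need_weights=False

Since FlashAttention's fused kernel avoids materializing the full attention matrix, callers would typically leave need_weights at its default of False; how FlashMHA handles need_weights=True is not visible in this hunk.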