Unverified commit 04fb1985, authored by Tri Dao, committed by GitHub

Merge pull request #43 from eric-tc-wong/patch-1

Update flash_attention.py
parents 19d12610 b410d14f
...
@@ -107,7 +107,7 @@ class FlashMHA(nn.Module):
             query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
                                           h=self.num_heads).unbind(dim=2)
             query, key = self.rotary_emb(query, key, seq_dimension=-3)
-            qkv = torch.stack([query, key, value], dim=2)
+            qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2)
         else:
             qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
         context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
...
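
The one-line change casts query and key back to x.dtype before re-stacking them with value. A plausible motivation (an assumption on my part, the commit message does not state it): the rotary-embedding step can return query and key in float32, and stacking a float32 tensor with the fp16/bf16 value would promote the whole qkv tensor, which the fused FlashAttention kernel does not accept. Below is a minimal, self-contained sketch of that dtype issue; apply_rotary_upcast is a hypothetical stand-in for a rotary embedding that computes in float32, not the repository's RotaryEmbedding.

import torch

def apply_rotary_upcast(q, k):
    # Hypothetical stand-in: a rotary embedding that computes in float32
    # for numerical stability and returns float32 tensors.
    return q.float(), k.float()

# Batch 2, sequence 16, 8 heads of dim 64, packed as (b, s, 3*h*d) like the Wqkv output.
x = torch.randn(2, 16, 3 * 8 * 64).half()
qkv = x.view(2, 16, 3, 8, 64)
query, key, value = qkv.unbind(dim=2)

query, key = apply_rotary_upcast(query, key)   # query/key are now float32

# Without the casts, stacking float32 query/key with the float16 value would
# promote qkv to float32 (or error on older PyTorch versions), so query and key
# are cast back to the input dtype first.
qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2)
assert qkv.dtype == torch.float16

Using .to(x.dtype) instead of .type(x.dtype) would be equivalent here.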