gaoqiong / flash-attention
"vscode:/vscode.git/clone" did not exist on "aa6e29cf049c0b6cb00f7a39e7011e174c0f98e7"
Commit 713ea302, authored Aug 05, 2022 by Tri Dao

Allow headdim 128 in FlashMHA interface

Parent: 2ed471ec

Showing 1 changed file with 8 additions and 11 deletions.

flash_attn/flash_attention.py (+8, -11)
@@ -24,20 +24,16 @@ class FlashAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
+    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                 max_s=None, need_weights=False):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                 if unpadded: (nnz, 3, h, d)
-            attn_mask: An implementation of BaseMask that encodes where each
-                query can attend to
-            key_padding_mask: An implementation of BaseMask that encodes how
-                many query each sequence in the batch consists of
+            key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert attn_mask is None
         assert qkv.dtype == torch.float16
         assert qkv.is_cuda
@@ -55,10 +51,9 @@ class FlashAttention(nn.Module):
             )
             output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         else:
-            key_padding_mask_bool = key_padding_mask.bool_matrix
             nheads = qkv.shape[-2]
             x = rearrange(qkv, 'b s three h d -> b s (three h d)')
-            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
+            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
            x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
             output_unpad = flash_attn_unpadded_qkvpacked_func(
                 x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
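
For reference, a minimal usage sketch (not part of this commit) of the updated FlashAttention.forward interface: the attn_mask argument is gone and key_padding_mask is now a plain bool tensor of shape (B, S) that is passed straight to unpad_input. Tensor sizes, the dropout value, and the causal flag below are illustrative assumptions; a CUDA device and an installed flash_attn package are assumed.

    # Sketch only: assumes flash_attn is installed and a CUDA GPU is available.
    import torch
    from flash_attn.flash_attention import FlashAttention

    batch, seqlen, nheads, headdim = 2, 128, 8, 64
    attn = FlashAttention(attention_dropout=0.0)

    # qkv packed as (B, S, 3, H, D), fp16 on GPU -- matches the asserts in forward()
    qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                      dtype=torch.float16, device='cuda')

    # key_padding_mask is now a plain bool tensor of shape (B, S): True = keep token
    key_padding_mask = torch.ones(batch, seqlen, dtype=torch.bool, device='cuda')
    key_padding_mask[:, seqlen // 2:] = False  # pretend the second half is padding

    # Padded tokens indicated by key_padding_mask are removed internally via unpad_input
    out = attn(qkv, key_padding_mask=key_padding_mask, causal=True)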
@@ -90,7 +85,7 @@ class FlashMHA(nn.Module):
         self.num_heads = num_heads
         assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
-        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
+        assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"
         assert use_rotary_emb in [None, '1d', '2d']
         self.use_rotary_emb = use_rotary_emb
@@ -103,8 +98,10 @@ class FlashMHA(nn.Module):
         self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
 
-    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
-                need_weights=False):
+    def forward(self, x, key_padding_mask=None):
+        """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+        key_padding_mask: bool tensor of shape (batch, seqlen)
+        """
         qkv = self.Wqkv(x)
         if self.use_rotary_emb:
             query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
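
A hedged sketch (not from the commit) of the configuration this change enables: head_dim == 128, together with the simplified FlashMHA.forward(x, key_padding_mask=None) signature. The embed_dim/num_heads split below is illustrative, and the attention_dropout, device, and dtype keyword arguments are assumptions based on the factory_kwargs usage visible in the diff.

    # Sketch only: assumes flash_attn is installed and a CUDA GPU is available.
    import torch
    from flash_attn.flash_attention import FlashMHA

    # embed_dim / num_heads = 2048 / 16 = 128, rejected by the old assert, allowed now
    mha = FlashMHA(embed_dim=2048, num_heads=16, attention_dropout=0.0,
                   device='cuda', dtype=torch.float16)

    x = torch.randn(2, 128, 2048, dtype=torch.float16, device='cuda')
    key_padding_mask = torch.ones(2, 128, dtype=torch.bool, device='cuda')

    # New simplified signature: the ignored positional args and attn_mask are gone
    out = mha(x, key_padding_mask=key_padding_mask)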