gaoqiong / flash-attention · Commits

Commit 3250ff3d, authored Sep 18, 2023 by Tri Dao

Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H)

parent 43617dea
Showing 2 changed files with 20 additions and 4 deletions.
csrc/flash_attn/flash_api.cpp   +19 -4
tests/test_flash_attn.py        +1  -0
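Why the swap helps: with multi-query attention (num_heads_k == 1) and a single query position (seqlen_q == 1), every query head attends to the same K/V head, so q's num_heads axis can be relabeled as the seqlen_q axis. The kernel then sees one attention problem with many query rows instead of many single-row heads, which parallelizes better. A minimal PyTorch sketch of this equivalence (illustration only, not part of the commit; attn_ref is a hypothetical reference implementation):

import torch

def attn_ref(q, k, v):
    # q: (b, seqlen_q, heads, d); k, v: (b, seqlen_k, heads, d)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))      # -> (b, heads, seqlen, d)
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    return (scores.softmax(dim=-1) @ v).transpose(1, 2)   # -> (b, seqlen_q, heads, d)

b, h, d, sk = 2, 8, 64, 37
q = torch.randn(b, 1, h, d)     # seqlen_q = 1, h query heads
k = torch.randn(b, sk, 1, d)    # MQA: a single K/V head
v = torch.randn(b, sk, 1, d)

# Direct MQA: broadcast the single K/V head across all h query heads.
out = attn_ref(q, k.expand(b, sk, h, d), v.expand(b, sk, h, d))
# Swapped view: the h heads become h query rows against one head.
out_swapped = attn_ref(q.transpose(1, 2), k, v).transpose(1, 2)
torch.testing.assert_close(out, out_swapped)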
csrc/flash_attn/flash_api.cpp
@@ -235,13 +235,13 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
     }
 }
 
 std::vector<at::Tensor>
-mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
+mha_fwd(at::Tensor &q,               // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
-        const bool is_causal,
+        bool is_causal,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -271,8 +271,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     const auto sizes = q.sizes();
 
     const int batch_size = sizes[0];
-    const int seqlen_q = sizes[1];
-    const int num_heads = sizes[2];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
     const int head_size_og = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
@@ -280,6 +280,15 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+
+    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
+    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 and p_dropout == 0.f and head_size_og % 8 == 0;
+    if (seqlenq_nheads_swapped) {
+        q = q.transpose(1, 2);
+        std::swap(seqlen_q, num_heads);
+    }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
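The is_causal = false reset relies on FlashAttention's bottom-right-aligned causal mask (the convention since v2.1): the lone query is treated as the last position, so no key is masked out and causal attention equals non-causal attention. A small sketch of the mask arithmetic (illustration only, not the kernel's code):

import torch

sq, sk = 1, 37
# Bottom-right-aligned causal mask: query i may see keys j <= i + (sk - sq).
mask = torch.arange(sk)[None, :] <= torch.arange(sq)[:, None] + (sk - sq)
assert mask.all()   # with sq == 1 nothing is masked, so causal == non-causal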
@@ -388,6 +397,12 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
         if (out_.has_value()) { out_.value().copy_(out); }
     }
 
+    if (seqlenq_nheads_swapped) {
+        out = out.transpose(1, 2);
+        out_padded = out_padded.transpose(1, 2);
+        q_padded = q_padded.transpose(1, 2);
+        softmax_lse = softmax_lse.transpose(1, 2);
+    }
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
tests/test_flash_attn.py
@@ -908,6 +908,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
+        (1, 147),
         (113, 203),
         (128, 217),
         (113, 211),
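The new (1, 147) case covers the swapped path in the parametrized tests. A hedged usage sketch of a call that should take this path, assuming the flash_attn 2.x Python interface flash_attn_func and a CUDA device (all trigger conditions hold: seqlen_q == 1, one K/V head, several query heads, zero dropout, head size divisible by 8):

import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1, 32, 128, device="cuda", dtype=torch.float16)   # seqlen_q = 1, 32 query heads
k = torch.randn(2, 147, 1, 128, device="cuda", dtype=torch.float16)  # MQA: num_heads_k = 1
v = torch.randn(2, 147, 1, 128, device="cuda", dtype=torch.float16)
# causal=True is internally reset to False since seqlen_q == 1.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 1, 32, 128])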