gaoqiong / flash-attention · Commits

Commit 9eb3d099, authored Apr 07, 2024 by Tri Dao

    Transpose out when swapping seqlen_q and num_groups

Parent: f692b98d
Showing 1 changed file, csrc/flash_attn/flash_api.cpp, with 12 additions and 4 deletions (+12 −4)
csrc/flash_attn/flash_api.cpp (+12 −4)

@@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     params.num_splits = num_splits;
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
         if (num_splits < 1) {
-            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+            // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
         }
         if (params.num_splits > 1) {
             at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
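Note on the hunk above: num_splits_heuristic decides how many chunks to split the KV sequence into for the split-KV decode kernel, based on how well the resulting grid of blocks fills the GPU's SMs; the added comment records why the SM count is doubled before being passed in. Below is a simplified, illustrative stand-in for that kind of occupancy heuristic; the function name, thresholds, and exact logic are assumptions here, not the library's implementation.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Illustrative sketch of a split-count heuristic (assumed names and
    // thresholds). Idea: if there are too few independent blocks to fill the
    // SMs, split the KV dimension; prefer the smallest split count whose
    // "wave efficiency" (how full the last wave of blocks is) is close to the
    // best achievable.
    int pick_num_splits(int num_blocks, int num_sms, int num_n_blocks, int max_splits) {
        if (num_blocks >= num_sms) { return 1; }           // already enough parallelism
        max_splits = std::min({max_splits, num_sms, num_n_blocks});
        std::vector<float> eff(max_splits + 1, 0.f);
        float best_eff = 0.f;
        for (int s = 1; s <= max_splits; ++s) {
            float waves = float(num_blocks * s) / float(num_sms);
            eff[s] = waves / std::ceil(waves);             // 1.0 means every wave is full
            best_eff = std::max(best_eff, eff[s]);
        }
        for (int s = 1; s <= max_splits; ++s) {
            if (eff[s] >= 0.85f * best_eff) { return s; }  // smallest "good enough" split
        }
        return 1;
    }

    int main() {
        // Hypothetical decode-time launch: 8 (batch x heads x M-blocks) blocks
        // on a GPU with 108 SMs, counted twice as in the diff; 64 KV blocks.
        std::printf("num_splits = %d\n", pick_num_splits(8, 108 * 2, 64, 128));
        return 0;
    }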
@@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
         seqlen_q = ngroups;
         num_heads = num_heads_k;
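For context on the lines above: the swap only fires at decode time (seqlen_q == 1) with grouped-query attention (num_heads > num_heads_k). The ngroups query heads that share each KV head are re-viewed as a fake query-sequence dimension, which gives the kernel more query rows to parallelize over, and moving the ngroups declaration out of the if block makes it visible to the output handling further down. A shape-only sketch with libtorch follows; the concrete sizes are illustrative assumptions.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        // Assumed GQA decode shapes: batch 2, one query token, 32 query heads
        // sharing 8 KV heads (so ngroups = 4), head dim 128.
        const int64_t batch_size = 2, num_heads = 32, num_heads_k = 8, head_size_og = 128;
        const int64_t ngroups = num_heads / num_heads_k;

        auto q = torch::randn({batch_size, /*seqlen_q=*/1, num_heads, head_size_og});

        // Same transformation as the hunk above: split the head dimension into
        // (num_heads_k, ngroups) and move the group axis where seqlen_q was.
        auto q_swapped = q.reshape({batch_size, num_heads_k, ngroups, head_size_og})
                          .transpose(1, 2);

        std::cout << q_swapped.sizes() << std::endl;  // [2, 4, 8, 128]
        return 0;
    }

With these assumed sizes the kernel sees a (2, 4, 8, 128) problem instead of (2, 1, 32, 128): four query rows per KV head rather than one.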
@@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);
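This hunk is the fix named in the commit title: the shape check for a caller-supplied out no longer uses the already-swapped seqlen_q / num_heads variables, and out is then put through the same reshape/transpose as q so it matches the swapped layout. The libtorch sketch below (assumed shapes) illustrates why a view is enough in the contiguous case: the reshape and transpose share storage with the caller's tensor, so writes through the swapped view still land in the caller's buffer.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        // Assumed shapes: the caller preallocates out in the usual
        // (batch_size, seqlen_q=1, num_heads, head_size) layout.
        const int64_t batch_size = 2, num_heads = 32, num_heads_k = 8, head_size_og = 128;
        const int64_t ngroups = num_heads / num_heads_k;

        auto out = torch::zeros({batch_size, 1, num_heads, head_size_og});

        // Same reshape + transpose the hunk applies to out. On a contiguous
        // tensor this yields a view, so writes into the (b, ngroups,
        // num_heads_k, d) view are visible in the original tensor.
        auto out_view = out.reshape({batch_size, num_heads_k, ngroups, head_size_og})
                           .transpose(1, 2);
        out_view.fill_(1.0);

        std::cout << std::boolalpha
                  << (out_view.data_ptr() == out.data_ptr())   // true: same storage
                  << " " << out.min().item<float>()            // 1: writes reached out
                  << std::endl;
        return 0;
    }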
@@ -571,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
         max_seqlen_q = ngroups;
         num_heads = num_heads_k;
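The varlen path applies the same idea, with one difference worth spelling out: q is stored flattened as (total_q, num_heads, head_size), so the swap ends with an extra reshape that folds the ngroups dimension back into the leading token dimension, and max_seqlen_q rather than seqlen_q becomes ngroups. A shape-only libtorch sketch with assumed sizes:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        // Assumed varlen decode: 3 sequences, one query token each, so
        // total_q == batch_size == 3; 32 query heads over 8 KV heads, dim 128.
        const int64_t batch_size = 3, num_heads = 32, num_heads_k = 8, head_size_og = 128;
        const int64_t ngroups = num_heads / num_heads_k;

        auto q = torch::randn({/*total_q=*/batch_size, num_heads, head_size_og});

        // Mirror of the hunk above: split heads, swap the group axis forward,
        // then flatten the ngroups "query positions" back into the token dim.
        auto q_swapped = q.reshape({batch_size, num_heads_k, ngroups, head_size_og})
                          .transpose(1, 2)
                          .reshape({batch_size * ngroups, num_heads_k, head_size_og});

        std::cout << q_swapped.sizes() << std::endl;  // [12, 8, 128]
        return 0;
    }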
@@ -627,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, total_q, num_heads, head_size_og);
+        CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);