gaoqiong / flash-attention

Commit 2d8ea9a5, authored Sep 20, 2023 by Tri Dao
Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza)
parent 0705d271
Showing 5 changed files with 62 additions and 53 deletions.
csrc/flash_attn/flash_api.cpp                     +29  -26
csrc/flash_attn/src/flash_bwd_launch_template.h    +0   -2
csrc/flash_attn/src/flash_fwd_kernel.h            +17  -13
csrc/flash_attn/src/flash_fwd_launch_template.h   +14  -10
tests/test_flash_attn.py                           +2   -2
csrc/flash_attn/flash_api.cpp
@@ -282,11 +282,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case

-    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
-    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 and p_dropout == 0.f and head_size_og % 8 == 0;
-    if (seqlenq_nheads_swapped) {
-        q = q.transpose(1, 2);
-        std::swap(seqlen_q, num_heads);
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0;
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
     }

     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
@@ -353,9 +356,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      is_causal);

     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = is_sm90 || is_sm8x
-        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
-        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
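As a quick check on the simplified selection, the same ternary written out in Python for a few head sizes (purely illustrative; as the comment says, the constant has to stay in sync with run_mha_fwd_splitkv_dispatch):

    def block_n(head_size: int) -> int:
        # Mirrors the new C++ expression: head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)
        return 256 if head_size <= 64 else (128 if head_size <= 128 else 64)

    def num_n_blocks(seqlen_k: int, head_size: int) -> int:
        bn = block_n(head_size)
        return (seqlen_k + bn - 1) // bn  # ceil division, as in the C++ code

    for d in (64, 128, 256):
        print(d, block_n(d), num_n_blocks(4096, d))
    # 64 256 16
    # 128 128 32
    # 256 64 64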
@@ -369,6 +370,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
             params.oaccum_ptr = out_accum.data_ptr();
         }
+        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     }

     // number of times random will be generated per thread, to offset philox counter in thc random
@@ -397,11 +399,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         if (out_.has_value()) { out_.value().copy_(out); }
     }

-    if (seqlenq_nheads_swapped) {
-        out = out.transpose(1, 2);
-        out_padded = out_padded.transpose(1, 2);
-        q_padded = q_padded.transpose(1, 2);
-        softmax_lse = softmax_lse.transpose(1, 2);
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
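The matching un-swap on the way out can be sketched the same way (again an illustrative PyTorch snippet, not the repo's code): the kernel's (b, ngroups, h_k, d) output is transposed and flattened back to the caller's (b, 1, h, d) layout, and softmax_lse goes back to (b, h, 1). Note that at this point seqlen_q still holds ngroups, which is why the reshape uses num_heads_k * seqlen_q:

    import torch

    batch_size, num_heads_k, ngroups, head_size_og = 2, 4, 8, 128
    seqlen_q = ngroups  # inside the kernel, "seqlen_q" is ngroups after the swap

    out = torch.randn(batch_size, seqlen_q, num_heads_k, head_size_og)   # kernel-side layout (b, ngroups, h_k, d)
    softmax_lse = torch.randn(batch_size, num_heads_k, seqlen_q)         # kernel-side layout (b, h_k, ngroups)

    # Same reshapes as the C++ above: back to the caller's (b, 1, h, d) and (b, h, 1).
    out_user = out.transpose(1, 2).reshape(batch_size, 1, num_heads_k * seqlen_q, head_size_og)
    lse_user = softmax_lse.reshape(batch_size, num_heads_k * seqlen_q, 1)

    print(out_user.shape, lse_user.shape)  # torch.Size([2, 1, 32, 128]) torch.Size([2, 32, 1])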
@@ -1050,11 +1052,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case

-    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
-    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
-    if (seqlenq_nheads_swapped) {
-        q = q.transpose(1, 2);
-        std::swap(seqlen_q, num_heads);
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0;
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
     }

     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
@@ -1184,12 +1189,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.rotary_dim = 0;
     }

     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = is_sm90 || is_sm8x
-        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
-        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
-    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
     const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
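For the KV-cache path the only extra wrinkle is that appended keys (knew_ptr != nullptr) enlarge the key length seen by the split-KV kernel before the ceil-division. A small illustrative Python version of the two block counts computed above (sizes are made up; the appended length is taken to be seqlen_q, mirroring the hunk):

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    def kvcache_block_counts(seqlen_q, seqlen_k, head_size, appending_new_kv):
        block_n = 256 if head_size <= 64 else (128 if head_size <= 128 else 64)
        # New keys being appended extend the key length seen by the split-KV kernel.
        extra = seqlen_q if appending_new_kv else 0
        num_n_blocks = ceil_div(seqlen_k + extra, block_n)
        num_m_blocks = ceil_div(seqlen_q, 64)
        return num_n_blocks, num_m_blocks

    print(kvcache_block_counts(seqlen_q=1, seqlen_k=8192, head_size=128, appending_new_kv=True))  # (65, 1)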
@@ -1197,6 +1199,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     if (num_splits < 1) {
         params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
     }
+    TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     if (params.num_splits > 1) {
         at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
         at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
@@ -1219,9 +1222,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         }
     }

-    if (seqlenq_nheads_swapped) {
-        out = out.transpose(1, 2);
-        softmax_lse = softmax_lse.transpose(1, 2);
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
     return {out, softmax_lse};
 }
csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -123,14 +123,12 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
         kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

-//
 template<typename Kernel_traits, bool Is_dropout>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
     if (configure) return;
     run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
 }

-//
 template<typename T>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -1141,19 +1141,18 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, int Log_max_splits, bool Is_even_K, typename Params>
+template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
 inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
     constexpr int kMaxSplits = 1 << Log_max_splits;
-    constexpr int kBlockM = 16;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
+    constexpr int kNThreads = Kernel_traits::kNThreads;
     static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
-    // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer");
-    static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32");
-    static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads");
+    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
+    static_assert(kNThreads == 128, "We assume that each block has 128 threads");
     // Shared memory.
     // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
@@ -1169,17 +1168,17 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
                                    make_stride(params.b * params.h * params.seqlen_q, _1{}));
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
-    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads;
+    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

     // Read the LSE values from gmem and store them in shared memory, then tranpose them.
-    constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM;
+    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
         ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
-        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
+        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }
     // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
     __syncthreads();
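To get a feel for the indexing above, here is the same arithmetic in plain Python for one plausible configuration (kNThreads = 128 as asserted; kBlockM = 8 and kMaxSplits = 16 are just example values):

    kNThreads, kBlockM, kMaxSplits = 128, 8, 16   # example values; kNThreads == 128 is asserted in the kernel

    kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) // kNThreads   # 1
    kRowsPerLoadLSE = kNThreads // kBlockM                                 # 16 split-rows loaded per pass

    def lse_coords(tidx, l):
        # Which (split row, query column) of the LSE tile thread `tidx` reads on iteration `l`.
        row = l * kRowsPerLoadLSE + tidx // kBlockM
        col = tidx % kBlockM
        return row, col

    print(kNLsePerThread, kRowsPerLoadLSE)        # 1 16
    print(lse_coords(0, 0), lse_coords(127, 0))   # (0, 0) (15, 7)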
@@ -1187,7 +1186,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
     // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
     // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
-    // 16 rows, so each time we load we can load 8 rows).
+    // kBlockM rows, so each time we load we can load 128 / kBlockM rows).
     // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
     // static_assert(kThreadsPerSplit <= 32);
     static_assert(kRowsPerLoadTranspose <= 32);
@@ -1230,7 +1229,13 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
-    typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+    constexpr int kBlockN = kNThreads / kBlockM;
+    using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
+    using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                         GmemLayoutAtomOaccum{},
+                                                         Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
     auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
     Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
     Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
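The new copy layout arranges the 128 threads as a (kBlockM, kBlockN) tile with kBlockN = kNThreads / kBlockM and 4 fp32 values per store, so each pass moves 512 values regardless of kBlockM; that is the "512 elements at a time" figure quoted in the launch-side comment in flash_fwd_launch_template.h. A short illustrative calculation:

    kNThreads = 128
    for kBlockM in (4, 8, 16):                 # the values allowed by the new static_assert
        kBlockN = kNThreads // kBlockM         # threads per row of the Oaccum tile
        vals_per_row = kBlockN * 4             # 4 fp32 values per store, per the Val layout comment
        print(kBlockM, kBlockN, vals_per_row, kBlockM * vals_per_row)
    # 4 32 128 512
    # 8 16 64 512
    # 16 8 32 512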
@@ -1247,7 +1252,6 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
         for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
     }
     // Load Oaccum in then scale and accumulate to O
-    #pragma unroll 2
     for (int split = 0; split < params.num_splits; ++split) {
         flash::copy</*Is_even_MN=*/false, Is_even_K>(
             gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
@@ -1263,11 +1267,11 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
                 tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
             }
         }
-        // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); }
+        // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }
         }
         tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
     }
-    // if (cute::thread0()) { print(tOrO); }
+    // if (cute::thread0()) { print_tensor(tOrO); }
     Tensor rO = flash::convert_type<Element>(tOrO);
     // Write to gO
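The loop above is the split-KV recombination: assuming the standard log-sum-exp merge (each split s stores a locally normalized partial output O_s plus its log-sum-exp L_s, and the final result is sum_s exp(L_s - L) * O_s with L = logsumexp over the L_s), it can be checked numerically with a short PyTorch sketch (illustrative sizes, single query row):

    import torch

    # Illustrative check of the split-KV recombination: each split holds a partial
    # softmax-weighted sum (out_accum) and its log-sum-exp (lse_accum); the combine
    # step rescales each split by exp(lse_s - lse_total) and sums.
    torch.manual_seed(0)
    num_splits, seqlen_k, d = 4, 256, 64
    scores = torch.randn(seqlen_k)          # one query row against all keys
    v = torch.randn(seqlen_k, d)

    # Reference: ordinary softmax attention over the full key length.
    ref = torch.softmax(scores, dim=0) @ v

    # Split-KV: per-split partial results.
    chunks = scores.chunk(num_splits)
    v_chunks = v.chunk(num_splits)
    lse_accum = torch.stack([c.logsumexp(dim=0) for c in chunks])
    out_accum = torch.stack([torch.softmax(c, dim=0) @ vc for c, vc in zip(chunks, v_chunks)])

    # Combine, as in the loop above: out = sum_s exp(lse_s - lse_total) * out_s.
    lse_total = lse_accum.logsumexp(dim=0)
    lse_scale = (lse_accum - lse_total).exp()
    out = (lse_scale[:, None] * out_accum).sum(dim=0)

    print(torch.allclose(out, ref, atol=1e-5))  # True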
csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -20,10 +20,10 @@ __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
     flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }

-template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
+template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
 __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
     static_assert(Log_max_splits >= 1);
-    flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
+    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -93,22 +93,26 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         });
     });
     if (params.num_splits > 1) {
-        dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
+        // We want kBlockM to be as small as possible for more parallelism.
+        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
+        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
+        constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             if (params.num_splits <= 2) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 4) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 8) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 16) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 32) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 64) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 128) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             }
             C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
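The launch-side choice of kBlockM and the resulting combine grid can be sketched numerically (illustrative Python; in the repo the values come from Kernel_traits::kHeadDim and the params struct):

    def combine_launch_params(head_dim: int, batch: int, heads: int, seqlen_q: int):
        # Mirrors the new C++: smaller kBlockM -> more combine blocks -> more parallelism.
        if head_dim % 128 == 0:
            kBlockM = 4
        elif head_dim % 64 == 0:
            kBlockM = 8
        else:
            kBlockM = 16
        grid_combine = (batch * heads * seqlen_q + kBlockM - 1) // kBlockM
        return kBlockM, grid_combine

    # A decode-style call after the seqlen_q/ngroups swap: seqlen_q is now ngroups.
    print(combine_launch_params(head_dim=128, batch=4, heads=8, seqlen_q=8))  # (4, 64)
    print(combine_launch_params(head_dim=96,  batch=4, heads=8, seqlen_q=8))  # (16, 16)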
tests/test_flash_attn.py
@@ -1505,12 +1505,12 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 @pytest.mark.parametrize("rotary_interleaved", [False, True])
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
-# @pytest.mark.parametrize("rotary_fraction", [1.0])
+# @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [64])
+# @pytest.mark.parametrize("d", [128])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [