gaoqiong / flash-attention

Commit 9448264d, authored Jan 14, 2024 by Tri Dao

Remove seqq_parallel backward kernel that's not used

parent 1274ec3e
Showing 3 changed files with 25 additions and 535 deletions (+25 -535)
csrc/flash_attn/src/flash_bwd_kernel.h            +0   -450
csrc/flash_attn/src/flash_bwd_launch_template.h   +25  -82
csrc/flash_attn/src/kernel_traits.h               +0   -3
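For orientation (a summary and sketch of mine, not part of the commit): the removed "seqq-parallel" path parallelized the backward pass over query blocks, launching one thread block per (query block, batch, head) via grid_m(num_m_block, params.b, params.h), and accumulated the dK/dV contributions of all query blocks into fp32 buffers with atomicAdd. The stub below only illustrates that launch geometry under assumed sizes; the kernel and variable names here are hypothetical and do not exist in the repository.

#include <cuda_runtime.h>
#include <cstdio>

// One thread block per (query block, batch, head); every query block of a given
// (batch, head) targets the same dK/dV slab, hence the atomicAdd accumulation.
__global__ void seqq_parallel_bwd_stub(float* dkv_contrib_count) {
    const int m_block = blockIdx.x;  // query-block index: the parallel dimension in the seqq scheme
    const int bidb = blockIdx.y;     // batch index
    const int bidh = blockIdx.z;     // head index
    (void)m_block;
    if (threadIdx.x == 0) {
        atomicAdd(&dkv_contrib_count[bidb * gridDim.z + bidh], 1.0f);
    }
}

int main() {
    const int b = 2, h = 4, seqlen_q = 512, kBlockM = 128;      // assumed toy sizes
    const int num_m_block = (seqlen_q + kBlockM - 1) / kBlockM;  // mirrors the launcher in the diff
    float* counts;
    cudaMalloc(&counts, sizeof(float) * b * h);
    cudaMemset(counts, 0, sizeof(float) * b * h);
    dim3 grid_m(num_m_block, b, h);  // same shape as grid_m(num_m_block, params.b, params.h)
    seqq_parallel_bwd_stub<<<grid_m, 128>>>(counts);
    float host0 = 0.f;
    cudaMemcpy(&host0, counts, sizeof(float), cudaMemcpyDeviceToHost);
    printf("query blocks contributing to (batch 0, head 0): %.0f\n", host0);
    cudaFree(counts);
    return 0;
}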
csrc/flash_attn/src/flash_bwd_kernel.h
...
...
@@ -1141,442 +1141,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    // constexpr int kNWarps = Kernel_traits::kNWarps;
    constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value;
    constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;

    const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;

    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if (Is_causal) {
        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
    }
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    // We move K and V to the last block.
    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    // We'll advance gdKaccum and gdVaccum before the first write.
    const index_t row_offset_dkv_accum =
        ((bidb * params.h_k + (bidh / params.h_h_k_ratio)) * params.seqlen_k_rounded + n_block_max * kBlockN) * params.d_rounded;
    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
// We assume that params.d == kHeadDim for now
    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.v_row_stride, _1{}));
    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                            typename Kernel_traits::SmemLayoutQdO{});
    Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
    Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
    Tensor sdO = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutQdO{});
    Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
    Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
    Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});
    // Double buffer for sK
    Tensor sV = make_tensor(sK.data() + 2 * size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{});
    Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{});
    Tensor sdS = make_tensor(sV.data() + size(sV), typename Kernel_traits::SmemLayoutPdS{});
    Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
    Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
    Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{});
    Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
    Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
    Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(sdS.data().get())),
                                Shape<Int<kBlockM>>{});
    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
    Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);
    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);    // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);    // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);

    typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
    auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ);      // (MMA,MMA_N,MMA_K)
    Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK);      // (MMA,MMA_N,MMA_K)
    Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO);   // (MMA,MMA_N,MMA_K)
    Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV);     // (MMA,MMA_N,MMA_K)

    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
    Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle);            // (MMA, MMA_N, MMA_N)
    Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle);              // (MMA, MMA_K, MMA_N)
    Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle);              // (MMA, MMA_N, MMA_N)
    Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle);    // (MMA, MMA_K, MMA_N)

    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
    Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS);              // (MMA, MMA_N, MMA_N)
    Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle);     // (MMA, MMA_K, MMA_N)

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M_SdP, MMA_K
//
// Copy Atom retiling
//
    auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
    auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
    Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);

    auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
    auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
    Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);

    // Partition sP and sdS to match the accumulator partitioning
    // This has to be tiled_mma_sdp, not tiled_mma_dkv
    auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
    auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
    Tensor tPsP = smem_thr_copy_PdS.partition_D(sP);      // ((Atom,AtomNum),PIPE_M,PIPE_N)
    Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);   // ((Atom,AtomNum),PIPE_M,PIPE_N)

    auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
    auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
    Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
    Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);

    auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
    auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
    Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
    Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);

    auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);
    auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);
    Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);

    auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);
    auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
    Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
//
// PREDICATES
//
// Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));     // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);    // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }
// Prologue
    Tensor tdOrdO = make_fragment_like(tdOgdO);
    Tensor tdOrO = make_fragment_like(tdOgO);
    // TODO: Might need to exit early and write 0 to gdQ.
    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
    );

    Tensor tQrQ = make_fragment_like(tQgQ);
    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
    );

    int n_block = n_block_max - 1;
    if (n_block % 2 == 1) {
        tKsK.data() = tKsK.data() + size(sK);
        tSsK.data() = tSsK.data() + size(sK);
        tdQsKt.data() = tdQsKt.data() + size(sK);
    }

    flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
    );

    Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
    Tensor taccScS = thr_mma_sdp.partition_C(caccS);                             // (MMA,MMA_N,MMA_N)
    static_assert(decltype(size<0>(taccScS))::value == 4);
    // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices.
    Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
    Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) {
        const int row = get<0>(taccScS_row(mi));
        lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
    }

    cute::cp_async_fence();

    Tensor dP_sum = make_fragment_like(lse);
    cute::copy(tdOrdO, tdOsdO);
    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
        tdOrdO, tdOrO, sdPsum,
        Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout
    );
    __syncthreads();
    #pragma unroll
    for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); }
    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
                           bidb, bidh, tidx, params.h);

    clear(acc_dq);

    float alibi_slope = !Has_alibi
        ? 0.0f
        : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

    for (; n_block >= 0; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M_SdP, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();

        flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
                    smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);

        // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

        if (Has_alibi) {
            flash::apply_alibi<Is_causal>(
                scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                binfo.actual_seqlen_q, AtomLayoutMS * 16, alibi_slope
            );
        }
// We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would
// be some finite value for those indices. In the end when we multiply with K to get dQ,
// the corresponding values of K would be 0, so the result would still be correct.
        if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
            flash::apply_mask_causal(
                scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                binfo.actual_seqlen_q, AtomLayoutMS * 16
            );
        }
// Compute the exponential value.
        flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
        if (Is_dropout) {
            int warp_id = tidx / 32;
            int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
            // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
            static_assert(MMA_N_SdP % 2 == 0);
            int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                scores, block_row_idx, block_col_idx, AtomLayoutMS
            );
        }
        // Convert scores from fp32 to fp16/bf16
        Tensor rP = !Is_dropout ? flash::convert_type<Element>(scores) : flash::convert_type_relu<Element>(scores);
        // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
        Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
        Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP);     // ((Atom,AtomNum), MMA_N, MMA_N)
        cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);

        Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)
        CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s));  // MMA
        CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s));  // MMA
        CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s));  // MMA

        clear(acc_dp);
        flash::gemm(acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
                    smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);

        // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
        Tensor dS = make_tensor(acc_dp.data(), scores.layout());
        auto pointwise_mult = [](float p, float dp, float d) {
            return p * (!Is_dropout || p >= 0 ? dp - d : d);
        };
        #pragma unroll
        for (int mi = 0; mi < size<0>(dS); ++mi) {
            #pragma unroll
            for (int ni = 0; ni < size<1>(dS); ++ni) {
                dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
            }
        }

        Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
        // Convert dS from fp32 to fp16
        Tensor tdSrdS = flash::convert_type<Element>(dS_reshaped);
        Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS);     // ((Atom,AtomNum), MMA_N, MMA_N)
        cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
        __syncthreads();

        if (n_block > 0) {
            // Double buffer for sK
            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
            tKsK.data() = tKsK.data() + sK_offset;
            tSsK.data() = tSsK.data() + sK_offset;
            // Advance gK, gV
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
        clear(acc_dv);
        flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(acc_dv); }
        tdVgdVaccum.data() = tdVgdVaccum.data() + (-int(kBlockN * params.d_rounded));
        #pragma unroll
        for (int i = 0; i < size(acc_dv); ++i) { atomicAdd(&tdVgdVaccum(i), acc_dv(i)); }

        __syncthreads();
        Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
        clear(acc_dk);
        flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
        tdKgdKaccum.data() = tdKgdKaccum.data() + (-int(kBlockN * params.d_rounded));
        #pragma unroll
        for (int i = 0; i < size(acc_dk); ++i) { atomicAdd(&tdKgdKaccum(i), acc_dk(i)); }

        flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
                    smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
        // Double buffer for sK
        tdQsKt.data() = tdQsKt.data() + (n_block % 2 == 0 ? size(sK) : -size(sK));
    }
// Epilogue
    #pragma unroll
    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to fp16
    Tensor rdQ = flash::convert_type<Element>(acc_dq);

    Tensor sdQ = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutdQ{});

    // Partition sdV and sdK to match the accumulator partitioning
    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);        // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

    __syncthreads();
    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);

    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);

    __syncthreads();

    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv(const Params &params) {
...
...
@@ -1618,19 +1182,5 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) {
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params &params) {

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;

    compute_dq_dk_dv_1rowblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
}  // namespace flash
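The inner loop of the removed kernel runs two GEMMs (Q·Kᵀ for the scores and dO·Vᵀ for dP) and then the pointwise step seen in pointwise_mult above: dP_sum is the row-wise sum of dO ∘ O (filled in by dot_do_o), and dS = P ∘ (dP − dP_sum) before it feeds the dK and dQ GEMMs. Below is a tiny host-side reference of just that pointwise step for the no-dropout case; it is a sketch of mine with made-up toy values, not code from the repository.

#include <vector>
#include <cstdio>

int main() {
    const int M = 2, N = 3;                                           // assumed tiny tile
    std::vector<float> P  = {0.2f, 0.3f, 0.5f, 0.1f, 0.1f, 0.8f};     // softmax probabilities
    std::vector<float> dP = {0.4f, 0.1f, 0.2f, 0.9f, 0.3f, 0.6f};     // dO * V^T entries
    std::vector<float> dP_sum = {0.25f, 0.63f};                       // per-row sum of dO .* O
    std::vector<float> dS(M * N);
    for (int mi = 0; mi < M; ++mi) {
        for (int ni = 0; ni < N; ++ni) {
            // dS_ij = P_ij * (dP_ij - dP_sum_i), the no-dropout branch of pointwise_mult
            dS[mi * N + ni] = P[mi * N + ni] * (dP[mi * N + ni] - dP_sum[mi]);
        }
    }
    for (float v : dS) { printf("%f\n", v); }
    return 0;
}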
csrc/flash_attn/src/flash_bwd_launch_template.h
...
...
@@ -29,11 +29,6 @@ __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params pa
    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
    flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
    flash::convert_dQ<Kernel_traits>(params, nsplits);
...
...
@@ -100,48 +95,6 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    dim3 grid_n(num_n_block, params.b, params.h_k);
    flash_bwd_clear_dkvaccum_kernel<Kernel_traits><<<grid_n, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
    // for cu_seqlens_k as well.
    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
                    if (smem_size_dq_dk_dv >= 48 * 1024) {
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                    }
                    kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                });
            });
        });
    });

    auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
    }
    kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    if (configure) return;
...
...
@@ -202,7 +155,6 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
// } else {
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
// }
}
});
...
...
@@ -231,20 +183,16 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // if (params.h == params.h_k) {
        if (max_smem_per_block >= 116 * 1024) {
            if constexpr (!Is_dropout) {  // 92KB
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
            } else {  // 116 KB
                // This is faster for dropout since we don't have many registers to spare
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            }
        } else {
            if (max_smem_per_block >= 116 * 1024) {
                if constexpr (!Is_dropout) {  // 92KB
                    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
                } else {  // 116 KB
                    // This is faster for dropout since we don't have many registers to spare
                    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
                }
        // } else {
        //     run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
        // }
            }
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
        }
    });
}
...
...
@@ -261,29 +209,24 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // if (params.h == params.h_k) {
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
        // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
        } else {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
        }
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
        // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
        } else {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
        }
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
        // } else {
        //     run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
        // }
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
    });
}
...
...
csrc/flash_attn/src/kernel_traits.h
...
...
@@ -288,9 +288,6 @@ struct Flash_bwd_kernel_traits : public Base {
        + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
    static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + kSmemdSSize + kSmemPSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
...
...