gaoqiong / flash-attention / Commits

Commit 3524e13c
Authored Aug 13, 2023 by Tri Dao

    Update to Cutlass 3.1

parent 364a5b4a
Changes: 5 changed files with 229 additions and 171 deletions (+229 -171)
csrc/cutlass                                +1    -1
csrc/flash_attn/src/flash_bwd_kernel.h      +164  -118
csrc/flash_attn/src/flash_fwd_kernel.h      +39   -30
csrc/flash_attn/src/utils.h                 +19   -16
tests/test_flash_attn.py                    +6    -6
csrc/cutlass @ 6f474202 (compare c4f6b8c6...6f474202)
-Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933
+Subproject commit 6f47420213f757831fae65c686aa471749fa8d60
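Almost all of the flash_bwd_kernel.h changes below apply one mechanical pattern, driven by the CuTe API in Cutlass 3.1: instead of keeping only the per-thread slice of a temporary TiledCopy, the kernel now names the TiledCopy object, derives the thread slice from it, and performs data movement through an explicitly qualified cute::copy on the tiled object. A minimal before/after sketch of the pattern (illustrative only; GmemTiledCopydO and the tensor names stand in for the kernel's real types):

    // Before: a temporary TiledCopy; only its per-thread slice survives.
    auto gmem_thr_copy_dO = typename Kernel_traits::GmemTiledCopydO{}.get_thread_slice(tidx);
    copy(gmem_thr_copy_dO, tdOgdO, tdOrdO);   // unqualified copy on the slice

    // After: the TiledCopy is kept alive alongside its per-thread slice.
    // Partitioning (partition_S / partition_D) still goes through the slice,
    // while the copy itself is dispatched on the tiled object.
    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
    cute::copy(gmem_tiled_copy_dO, tdOgdO, tdOrdO);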
csrc/flash_attn/src/flash_bwd_kernel.h  (view file @ 3524e13c)
@@ -147,14 +147,16 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                 Shape<Int<kBlockM>>{}, Stride<_1>{});
-    auto gmem_thr_copy_dO = typename Kernel_traits::GmemTiledCopydO{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
+    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
     // TODO: careful, we're zeroing out dQaccum with type float4, but when
     // we do atomicAdds, we use type float. The layouts are different. Check this.
-    auto gmem_thr_copy_dQ_accum = typename Kernel_traits::GmemTiledCopydQaccum{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
+    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
     Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
     Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
-    Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_D(gdQaccum);
+    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
     Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
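The TODO retained in this hunk flags a real subtlety worth spelling out: dQaccum is zero-filled through a float4-typed copy but later accumulated into with float-typed atomicAdds. The zero-fill itself is safe regardless of layout, since writing zeros through a wider vector view touches the same bytes as element-wise stores; only the traversal order differs. A hedged illustration in plain CUDA (zero_as_float4 is a hypothetical name, not code from this commit):

    // Zero-filling a float buffer through a float4 view writes the same bytes
    // as writing 0.f element-wise -- provided n is a multiple of 4 and buf is
    // 16-byte aligned, assumptions this sketch makes explicit.
    __global__ void zero_as_float4(float *buf, int n) {
        float4 *buf4 = reinterpret_cast<float4 *>(buf);   // requires alignment
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n / 4) { buf4[i] = make_float4(0.f, 0.f, 0.f, 0.f); }
    }
    // Later accumulation is element-wise and layout-agnostic:
    //     atomicAdd(&buf[j], v);   // each float lands in the right slot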
@@ -168,10 +170,10 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     Tensor tdOrdO = make_fragment_like(tdOgdO);
     Tensor tdOrO = make_fragment_like(tdOgO);
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
-        gmem_thr_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
     );
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
-        gmem_thr_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
     );
     // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
     // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
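The two comment lines above (unchanged in this hunk) encode a deferred-rescaling argument that can be made explicit. Because dQ and dK are linear in dP, one multiply by 1/p on the accumulated results equals a multiply on every iteration (with p the dropout keep-probability; a restatement of the comment, not an addition from this commit):

    (1/p) * sum_j dS_j * K_j  =  sum_j ((1/p) * dS_j) * K_j

The consistency cost is the one the comment names: dP_sum, which participates against the unscaled dP inside the loop, must itself stay scaled down until the final rescale by scale_softmax_rp_dropout, visible in the convert_dQ and convert_dKV hunks below.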
@@ -181,7 +183,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     if (Clear_dQaccum) {
         Tensor zero = make_fragment_like(tdQgdQaccum);
         clear(zero);
-        copy(gmem_thr_copy_dQ_accum, zero, tdQgdQaccum);
+        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
     }
 }
@@ -213,13 +215,14 @@ inline __device__ void clear_dKVaccum(const Params &params) {
     Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                   Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
-    auto gmem_thr_copy_dKV_accum = typename Kernel_traits::GmemTiledCopydQaccum{}.get_thread_slice(tidx);
-    Tensor tdKgdKaccum = gmem_thr_copy_dKV_accum.partition_D(gdKaccum);
-    Tensor tdVgdVaccum = gmem_thr_copy_dKV_accum.partition_D(gdVaccum);
+    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
+    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
+    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
+    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
     Tensor zero = make_fragment_like(tdKgdKaccum);
     clear(zero);
-    copy(gmem_thr_copy_dKV_accum, zero, tdKgdKaccum);
-    copy(gmem_thr_copy_dKV_accum, zero, tdVgdVaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -264,22 +267,25 @@ inline __device__ void convert_dQ(const Params &params) {
     Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdQ{});
-    auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_dQ_accum = typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
+    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
+    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
     typename Kernel_traits::TiledMmadQ tiled_mma_dq;
-    auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx);
+    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
+    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
     Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
     Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
-    Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_S(gdQaccum);
+    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
     Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
     CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
     Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
-    copy(gmem_thr_copy_dQ_accum, tdQgdQaccum, tdQrdQaccum);
+    cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
     #pragma unroll
     for (int i = 0; i < size(acc_dq); ++i) {
         acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout;
@@ -287,10 +293,10 @@ inline __device__ void convert_dQ(const Params &params) {
     // Convert acc_dq from fp32 to fp16
     Tensor rdQ = flash::convert_type<Element>(acc_dq);
     Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
-    copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ);
+    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
     __syncthreads();
     Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
-    copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ);
+    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
     Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
@@ -299,7 +305,7 @@ inline __device__ void convert_dQ(const Params &params) {
     for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }
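The "Clear_OOB_K must be false" comment in this hunk is one instance of flash::copy's four compile-time predication flags. A simplified, illustrative model of their contract (copy_2d is a hypothetical name and this is a sketch of the semantics, not the real implementation):

    // Is_even_MN / Is_even_K: the tile is full in that direction, so the
    // bounds check can be skipped. Clear_OOB_MN / Clear_OOB_K: out-of-bounds
    // destination elements are zero-filled instead of left untouched.
    template <bool Is_even_MN, bool Is_even_K, bool Clear_OOB_MN, bool Clear_OOB_K>
    void copy_2d(const float *src, float *dst, int rows, int cols,
                 int max_row /*seqlen bound*/, const bool *col_pred /*k < d*/) {
        for (int m = 0; m < rows; ++m) {
            bool row_ok = Is_even_MN || m < max_row;
            for (int k = 0; k < cols; ++k) {
                bool col_ok = Is_even_K || col_pred[k];
                if (row_ok && col_ok) {
                    dst[m * cols + k] = src[m * cols + k];
                } else if ((!row_ok && Clear_OOB_MN) || (row_ok && !col_ok && Clear_OOB_K)) {
                    dst[m * cols + k] = 0.f;    // zero-fill out of bounds
                }   // otherwise leave dst alone: with Clear_OOB_K = false,
                    // nothing is ever written past the head dimension in gmem
            }
        }
    }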
@@ -354,11 +360,14 @@ inline __device__ void convert_dKV(const Params &params) {
                              typename Kernel_traits::SmemLayoutdKV{});
     Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
-    auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_dKV_accum = typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
+    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
+    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
     typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
-    auto smem_thr_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv).get_thread_slice(tidx);
+    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
+    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
     Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
     Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
@@ -366,8 +375,8 @@ inline __device__ void convert_dKV(const Params &params) {
     Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
     Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
-    Tensor tdKgdKaccum = gmem_thr_copy_dKV_accum.partition_S(gdKaccum);
-    Tensor tdVgdVaccum = gmem_thr_copy_dKV_accum.partition_S(gdVaccum);
+    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
+    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
     Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
     Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
@@ -376,8 +385,8 @@ inline __device__ void convert_dKV(const Params &params) {
     Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
     Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
-    copy(gmem_thr_copy_dKV_accum, tdKgdKaccum, tdKrdKaccum);
-    copy(gmem_thr_copy_dKV_accum, tdVgdVaccum, tdVrdVaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
     #pragma unroll
     for (int i = 0; i < size(acc_dk); ++i) {
         acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
@@ -391,13 +400,13 @@ inline __device__ void convert_dKV(const Params &params) {
     Tensor rdV = flash::convert_type<Element>(acc_dv);
     Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
     Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
-    copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK);
-    copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV);
+    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
+    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
     __syncthreads();
     Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
     Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
-    copy(gmem_thr_copy_dKV, tdKsdK, tdKrdK);
-    copy(gmem_thr_copy_dKV, tdVsdV, tdVrdV);
+    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
+    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
     Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
@@ -406,10 +415,10 @@ inline __device__ void convert_dKV(const Params &params) {
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
 }
@@ -511,20 +520,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
                                 Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});
-    auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
     using GmemTiledCopydO = std::conditional_t<
         Is_first,
         typename Kernel_traits::GmemTiledCopydO,
         typename Kernel_traits::GmemTiledCopyQKV
     >;
-    auto gmem_thr_copy_dO = GmemTiledCopydO{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx);
+    GmemTiledCopydO gmem_tiled_copy_dO;
+    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
+    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
     using GmemLayoutAtomdQaccum = std::conditional_t<
         !Seq_parallel,
         typename Kernel_traits::GmemTiledCopydQaccum,
        typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd
     >;
-    auto gmem_thr_copy_dQ_accum = GmemLayoutAtomdQaccum{}.get_thread_slice(tidx);
+    GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
+    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
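The two std::conditional_t aliases in this hunk pick a copy type at compile time: when Is_first the kernel loads dO through GmemTiledCopydO, otherwise it reuses the QKV layout, and the dQaccum copy switches to an atomic-add-friendly layout under Seq_parallel. For readers unfamiliar with the idiom, a self-contained example (the type names here are invented for illustration):

    #include <type_traits>

    struct VectorizedCopy {};   // hypothetical: wide, non-atomic stores
    struct AtomicAddCopy  {};   // hypothetical: layout suited to atomicAdd

    template <bool Seq_parallel>
    void run() {
        // Resolved entirely at compile time; no branch exists at runtime.
        using DQaccumCopy = std::conditional_t<!Seq_parallel, VectorizedCopy, AtomicAddCopy>;
        DQaccumCopy tiled_copy;   // instantiate whichever type was selected
        (void)tiled_copy;
    }

    int main() { run<false>(); run<true>(); }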
@@ -537,7 +550,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
     Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
-    Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_D(gdQaccum);
+    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
     // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); }
     // __syncthreads();
     // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) {
@@ -570,12 +583,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Copy Atom retiling
     //
-    auto smem_thr_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
+    auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
+    auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
     Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
     Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
     // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
-    auto smem_thr_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
+    auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
+    auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
     // if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); }
     // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
@@ -584,7 +599,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Partition sP and sdS to match the accumulator partitioning
     // This has to be tiled_mma_sdp, not tiled_mma_dkv
     // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
-    auto smem_thr_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
+    auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
+    auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
     Tensor tPsP = smem_thr_copy_PdS.partition_D(sP);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
     // if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); }
     // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); }
@@ -593,21 +609,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // }
     Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
-    auto smem_thr_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx);
+    auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
+    auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
     Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
     Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
-    auto smem_thr_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx);
+    auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
+    auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
     Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
     Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);
-    auto smem_thr_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq).get_thread_slice(tidx);
+    auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);
+    auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);
     Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);
-    auto smem_thr_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq).get_thread_slice(tidx);
+    auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);
+    auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
     Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
-    auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx);
+    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
+    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
     Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
     //
@@ -655,7 +676,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                              Shape<Int<kBlockN>, Int<kHeadDim>>{},
                              make_stride(params.dv_row_stride, _1{}));
-    auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydKV{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
+    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
     Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
     Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
     Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
@@ -669,10 +691,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     return;
 }
@@ -688,7 +710,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Kernel_traits::Is_V_in_regs) {
         // Clear the smem tiles to account for predicated off loads
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+            gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
         flash::cp_async_fence();
     }
@@ -698,18 +720,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (!Is_first) {
         // Clear the smem tiles to account for predicated off loads
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
+            gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
     } else {
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
+            gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
+            gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
     }
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-        gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
     );
     Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});  // (BLK_M,BLK_N) -> (blk_m,blk_n)
@@ -726,23 +748,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     }
     // Tensor tKrK = make_fragment_like(tKsK);
-    // // copy(gmem_thr_copy_QKV, tKgK(_, _, _, 0), tKrK);
-    // copy(gmem_thr_copy_QKV, tKgK, tKrK);
+    // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
+    // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-        gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     if (!Kernel_traits::Is_V_in_regs) {
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+            gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
     }
     flash::cp_async_fence();
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
-        copy(tdOrdO, tdOsdO);
+        cute::copy(tdOrdO, tdOsdO);
         dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
@@ -752,7 +774,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         __syncthreads();
         Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);
         CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view));            // M
-        copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view);
+        cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }
     auto seed = params.rng_state[0];
@@ -775,10 +797,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK);
     // #pragma unroll
     // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) {
-    //     copy(smem_thr_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
+    //     cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
     // }
     // if (cute::thread0()) { print(tSrK); }
-    flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV);
+    flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
+                smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
     // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
     Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
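This is the first of several call sites updated to a widened flash::gemm signature: the helper now receives the two TiledCopy objects in addition to the per-thread slices, where previously it received the slices alone. Schematically (argument names generalized from this diff; the helper itself presumably lives in utils.h, which this commit also touches, though that file's hunks are not shown in this excerpt):

    // old
    flash::gemm(acc, tA, tB, tsA, tsB, tiled_mma, smem_thr_copy_A, smem_thr_copy_B);
    // new
    flash::gemm(acc, tA, tB, tsA, tsB, tiled_mma,
                smem_tiled_copy_A, smem_tiled_copy_B, smem_thr_copy_A, smem_thr_copy_B);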
@@ -827,7 +850,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
     Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
     Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP);     // ((Atom,AtomNum), MMA_N, MMA_N)
-    copy(smem_thr_copy_PdS, tPaP, tPsP);
+    cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
     // if (cute::thread0()) { print(tPaP); }
     // __syncthreads();
     // if (cute::thread0()) { print(sP); }
@@ -850,7 +873,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread0()) { print(dP_sum); }
     flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
-        acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV
+        acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV,
+        smem_thr_copy_QdO, smem_thr_copy_KV
     );
     // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
@@ -877,7 +901,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                                            make_layout(get<0>(acc_dq.layout()),
                                                        get<2>(acc_dq.layout()),
                                                        get<1>(acc_dq.layout())));
-        copy(gmem_thr_copy_dQ_accum, tdQgdQaccum, acc_dq_reshaped);
+        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped);
     }
     if (Double_buffer && m_block > m_block_min) {
@@ -887,7 +911,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         tSsQ.data() = tSsQ.data() + sQ_offset;
         // Advance gQ
         tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
         flash::cp_async_fence();
     }
@@ -896,7 +920,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor tdSrdS = flash::convert_type<Element>(dS_reshaped);
     // if (cute::thread0()) { print(tPrP); }
     Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS);     // ((Atom,AtomNum), MMA_N, MMA_N)
-    copy(smem_thr_copy_PdS, tdSadS, tdSsdS);
+    cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
     __syncthreads();
     // Layout p_l = tPrP.layout();
@@ -904,7 +928,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
     // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
     // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
-    flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+    flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
+                smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
     // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
     // if (cute::thread0()) { print(acc_dv); }
@@ -915,15 +940,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
         if (Is_first) {
             tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
         } else {
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
             flash::cp_async_fence();
         }
     }
-    flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, smem_thr_copy_dS, smem_thr_copy_Kt);
+    flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
+                smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
     // if (cute::thread0()) { print(acc_dq); }
     if (m_block > m_block_min) {
@@ -945,7 +971,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                                                        get<2>(acc_dq.layout()),
                                                        get<1>(acc_dq.layout())));
         if (!Seq_parallel) {
-            copy(gmem_thr_copy_dQ_accum, acc_dq_reshaped, tdQgdQaccum);
+            cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum);
         } else {
             // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); }
             CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
@@ -958,10 +984,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Convert acc_dq from fp32 to fp16
         Tensor rdQ = flash::convert_type<Element>(acc_dq);
         Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
-        copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ);
+        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
     }
-    flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+    flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
+                smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
     // if (cute::thread0()) { print(acc_dk); }
     if (Double_buffer) {  // Double buffer for sQ
         tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
@@ -970,12 +997,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         __syncthreads();
         // Advance gQ
         tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
        flash::cp_async_fence();
     }
     if (Is_first && m_block > m_block_min) {
-        copy(tdOrdO, tdOsdO);
+        cute::copy(tdOrdO, tdOsdO);
         dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
@@ -983,14 +1010,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Is_last) {
         __syncthreads();
         Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
-        copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ);
+        cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
         tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
         Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
         Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
         #pragma unroll
         for (int m = 0; m < size<1>(tdQgdQ); ++m) {
             if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
-                copy(gmem_thr_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
+                cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
             }
         }
     }
@@ -1014,7 +1041,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
     // Partition sdV and sdK to match the accumulator partitioning
-    auto smem_thr_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv).get_thread_slice(tidx);
+    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
+    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
     Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
     Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
     Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
@@ -1026,8 +1054,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // If Is_last, there's already a __syncthreads() at the end of the loop.
     if (!Is_last) { __syncthreads(); }
-    copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK);
-    copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV);
+    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
+    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
     const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
         + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
@@ -1040,7 +1068,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                              Shape<Int<kBlockN>, Int<kHeadDim>>{},
                              make_stride(params.dv_row_stride, _1{}));
-    auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydKV{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
+    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
     Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
     Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
@@ -1048,9 +1077,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     __syncthreads();
     Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
-    copy(gmem_thr_copy_dKV, tdKsdK, tdKrdK);
+    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
     Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
-    copy(gmem_thr_copy_dKV, tdVsdV, tdVrdV);
+    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
     Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
     Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
     Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
@@ -1058,10 +1087,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
 }
...
@@ -1163,9 +1192,12 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1163,9 +1192,12 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
Tensor
sdPsum
=
make_tensor
(
make_smem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
sdS
.
data
().
get
())),
Tensor
sdPsum
=
make_tensor
(
make_smem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
sdS
.
data
().
get
())),
Shape
<
Int
<
kBlockM
>>
{});
Shape
<
Int
<
kBlockM
>>
{});
auto
gmem_thr_copy_QKV
=
typename
Kernel_traits
::
GmemTiledCopyQKV
{}.
get_thread_slice
(
tidx
);
typename
Kernel_traits
::
GmemTiledCopyQKV
gmem_tiled_copy_QKV
;
auto
gmem_thr_copy_dO
=
typename
Kernel_traits
::
GmemTiledCopydO
{}.
get_thread_slice
(
tidx
);
auto
gmem_thr_copy_QKV
=
gmem_tiled_copy_QKV
.
get_thread_slice
(
tidx
);
auto
gmem_thr_copy_dKV_accum
=
typename
Kernel_traits
::
GmemTiledCopydQaccumAtomicAdd
{}.
get_thread_slice
(
tidx
);
typename
Kernel_traits
::
GmemTiledCopydO
gmem_tiled_copy_dO
;
auto
gmem_thr_copy_dO
=
gmem_tiled_copy_dO
.
get_thread_slice
(
tidx
);
typename
Kernel_traits
::
GmemTiledCopydQaccumAtomicAdd
gmem_tiled_copy_dKVaccum
;
auto
gmem_thr_copy_dKVaccum
=
gmem_tiled_copy_dKVaccum
.
get_thread_slice
(
tidx
);
Tensor
tQgQ
=
gmem_thr_copy_QKV
.
partition_S
(
gQ
);
Tensor
tQgQ
=
gmem_thr_copy_QKV
.
partition_S
(
gQ
);
Tensor
tQsQ
=
gmem_thr_copy_QKV
.
partition_D
(
sQ
);
Tensor
tQsQ
=
gmem_thr_copy_QKV
.
partition_D
(
sQ
);
...
@@ -1176,8 +1208,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1176,8 +1208,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
Tensor
tKsK
=
gmem_thr_copy_QKV
.
partition_D
(
sK
);
Tensor
tKsK
=
gmem_thr_copy_QKV
.
partition_D
(
sK
);
Tensor
tVgV
=
gmem_thr_copy_QKV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVgV
=
gmem_thr_copy_QKV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV
=
gmem_thr_copy_QKV
.
partition_D
(
sV
);
Tensor
tVsV
=
gmem_thr_copy_QKV
.
partition_D
(
sV
);
Tensor
tdKgdKaccum
=
gmem_thr_copy_dKV
_
accum
.
partition_D
(
gdKaccum
);
Tensor
tdKgdKaccum
=
gmem_thr_copy_dKVaccum
.
partition_D
(
gdKaccum
);
Tensor
tdVgdVaccum
=
gmem_thr_copy_dKV
_
accum
.
partition_D
(
gdVaccum
);
Tensor
tdVgdVaccum
=
gmem_thr_copy_dKVaccum
.
partition_D
(
gdVaccum
);
typename
Kernel_traits
::
TiledMmaSdP
tiled_mma_sdp
;
typename
Kernel_traits
::
TiledMmaSdP
tiled_mma_sdp
;
auto
thr_mma_sdp
=
tiled_mma_sdp
.
get_thread_slice
(
tidx
);
auto
thr_mma_sdp
=
tiled_mma_sdp
.
get_thread_slice
(
tidx
);
...
@@ -1204,32 +1236,39 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1204,32 +1236,39 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
// Copy Atom retiling
// Copy Atom retiling
//
//
auto
smem_thr_copy_QdO
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma_sdp
).
get_thread_slice
(
tidx
);
auto
smem_tiled_copy_QdO
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma_sdp
);
auto
smem_thr_copy_QdO
=
smem_tiled_copy_QdO
.
get_thread_slice
(
tidx
);
Tensor
tSsQ
=
smem_thr_copy_QdO
.
partition_S
(
sQ
);
Tensor
tSsQ
=
smem_thr_copy_QdO
.
partition_S
(
sQ
);
Tensor
tdPsdO
=
smem_thr_copy_QdO
.
partition_S
(
sdO
);
Tensor
tdPsdO
=
smem_thr_copy_QdO
.
partition_S
(
sdO
);
auto
smem_thr_copy_KV
=
make_tiled_copy_B_warpcontiguousN
<
MMA_N_SdP
>
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma_sdp
).
get_thread_slice
(
tidx
);
auto
smem_tiled_copy_KV
=
make_tiled_copy_B_warpcontiguousN
<
MMA_N_SdP
>
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma_sdp
);
auto
smem_thr_copy_KV
=
smem_tiled_copy_KV
.
get_thread_slice
(
tidx
);
Tensor
tSsK
=
smem_thr_copy_KV
.
partition_S
(
sK
);
Tensor
tSsK
=
smem_thr_copy_KV
.
partition_S
(
sK
);
Tensor
tdPsV
=
smem_thr_copy_KV
.
partition_S
(
sV
);
Tensor
tdPsV
=
smem_thr_copy_KV
.
partition_S
(
sV
);
// Partition sP and sdS to match the accumulator partitioning
// Partition sP and sdS to match the accumulator partitioning
// This has to be tiled_mma_sdp, not tiled_mma_dkv
// This has to be tiled_mma_sdp, not tiled_mma_dkv
auto
smem_thr_copy_PdS
=
make_tiled_copy_C_warpcontiguousN
<
MMA_N_SdP
>
(
typename
Kernel_traits
::
SmemCopyAtomPdS
{},
tiled_mma_sdp
).
get_thread_slice
(
tidx
);
auto
smem_tiled_copy_PdS
=
make_tiled_copy_C_warpcontiguousN
<
MMA_N_SdP
>
(
typename
Kernel_traits
::
SmemCopyAtomPdS
{},
tiled_mma_sdp
);
auto
smem_thr_copy_PdS
=
smem_tiled_copy_PdS
.
get_thread_slice
(
tidx
);
Tensor
tPsP
=
smem_thr_copy_PdS
.
partition_D
(
sP
);
// ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor
tPsP
=
smem_thr_copy_PdS
.
partition_D
(
sP
);
// ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor
tdSsdS
=
smem_thr_copy_PdS
.
partition_D
(
sdS
);
// ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor
tdSsdS
=
smem_thr_copy_PdS
.
partition_D
(
sdS
);
// ((Atom,AtomNum),PIPE_M,PIPE_N)
auto
smem_thr_copy_PdSt
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma_dkv
).
get_thread_slice
(
tidx
);
auto
smem_tiled_copy_PdSt
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma_dkv
);
auto
smem_thr_copy_PdSt
=
smem_tiled_copy_PdSt
.
get_thread_slice
(
tidx
);
Tensor
tdVsPt
=
smem_thr_copy_PdSt
.
partition_S
(
sPt
);
Tensor
tdVsPt
=
smem_thr_copy_PdSt
.
partition_S
(
sPt
);
Tensor
tdKsdSt
=
smem_thr_copy_PdSt
.
partition_S
(
sdSt
);
Tensor
tdKsdSt
=
smem_thr_copy_PdSt
.
partition_S
(
sdSt
);
auto
smem_thr_copy_QdOt
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma_dkv
).
get_thread_slice
(
tidx
);
auto
smem_tiled_copy_QdOt
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma_dkv
);
auto
smem_thr_copy_QdOt
=
smem_tiled_copy_QdOt
.
get_thread_slice
(
tidx
);
Tensor
tdVsdOt
=
smem_thr_copy_QdOt
.
partition_S
(
sdOt
);
Tensor
tdVsdOt
=
smem_thr_copy_QdOt
.
partition_S
(
sdOt
);
Tensor
tdKsQt
=
smem_thr_copy_QdOt
.
partition_S
(
sQt
);
Tensor
tdKsQt
=
smem_thr_copy_QdOt
.
partition_S
(
sQt
);
auto
smem_thr_copy_dS
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma_dq
).
get_thread_slice
(
tidx
);
auto
smem_tiled_copy_dS
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma_dq
);
auto
smem_thr_copy_dS
=
smem_tiled_copy_dS
.
get_thread_slice
(
tidx
);
Tensor
tdQsdS
=
smem_thr_copy_dS
.
partition_S
(
sdS
);
Tensor
tdQsdS
=
smem_thr_copy_dS
.
partition_S
(
sdS
);
auto
smem_thr_copy_Kt
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma_dq
).
get_thread_slice
(
tidx
);
auto
smem_tiled_copy_Kt
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma_dq
);
auto
smem_thr_copy_Kt
=
smem_tiled_copy_Kt
.
get_thread_slice
(
tidx
);
Tensor
tdQsKt
=
smem_thr_copy_Kt
.
partition_S
(
sKt
);
Tensor
tdQsKt
=
smem_thr_copy_Kt
.
partition_S
(
sKt
);
//
//
...
@@ -1263,15 +1302,15 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1263,15 +1302,15 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
// TODO: Might need to exit early and write 0 to gdQ.
// TODO: Might need to exit early and write 0 to gdQ.
flash
::
copy
<
/*Is_even_MN=*/
false
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
flash
::
copy
<
/*Is_even_MN=*/
false
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
gmem_t
hr
_copy_dO
,
tdOgdO
,
tdOrdO
,
tQcQ
,
tQpQ
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
gmem_t
iled
_copy_dO
,
tdOgdO
,
tdOrdO
,
tQcQ
,
tQpQ
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
);
flash
::
copy
<
/*Is_even_MN=*/
false
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
flash
::
copy
<
/*Is_even_MN=*/
false
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
gmem_t
hr
_copy_dO
,
tdOgO
,
tdOrO
,
tQcQ
,
tQpQ
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
gmem_t
iled
_copy_dO
,
tdOgO
,
tdOrO
,
tQcQ
,
tQpQ
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
);
Tensor
tQrQ
=
make_fragment_like
(
tQgQ
);
Tensor
tQrQ
=
make_fragment_like
(
tQgQ
);
flash
::
copy
<
/*Is_even_MN=*/
false
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
flash
::
copy
<
/*Is_even_MN=*/
false
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
gmem_t
hr
_copy_QKV
,
tQgQ
,
tQsQ
,
tQcQ
,
tQpQ
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
gmem_t
iled
_copy_QKV
,
tQgQ
,
tQsQ
,
tQcQ
,
tQpQ
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
);
int
n_block
=
n_block_max
-
1
;
int
n_block
=
n_block_max
-
1
;
...
@@ -1282,10 +1321,10 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1282,10 +1321,10 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
}
}
flash
::
copy
<
Is_even_N
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
flash
::
copy
<
Is_even_N
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
gmem_t
hr
_copy_QKV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
gmem_t
iled
_copy_QKV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
);
);
flash
::
copy
<
Is_even_N
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
flash
::
copy
<
Is_even_N
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
gmem_t
hr
_copy_QKV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
gmem_t
iled
_copy_QKV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
);
);
Tensor
caccS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
// (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor
caccS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
// (BLK_M,BLK_N) -> (blk_m,blk_n)
...
@@ -1304,7 +1343,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1304,7 +1343,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
cute
::
cp_async_fence
();
cute
::
cp_async_fence
();
Tensor
dP_sum
=
make_fragment_like
(
lse
);
Tensor
dP_sum
=
make_fragment_like
(
lse
);
copy
(
tdOrdO
,
tdOsdO
);
cute
::
copy
(
tdOrdO
,
tdOsdO
);
dot_do_o
<
Kernel_traits
::
kGmemThreadsPerRow
>
(
dot_do_o
<
Kernel_traits
::
kGmemThreadsPerRow
>
(
tdOrdO
,
tdOrO
,
sdPsum
,
sdPsum
,
tdOrdO
,
tdOrO
,
sdPsum
,
sdPsum
,
Kernel_traits
::
kNThreads
/
(
Kernel_traits
::
kGmemThreadsPerRow
),
params
.
p_dropout
Kernel_traits
::
kNThreads
/
(
Kernel_traits
::
kGmemThreadsPerRow
),
params
.
p_dropout
...
@@ -1324,7 +1363,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1324,7 +1363,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
flash
::
cp_async_wait
<
0
>
();
flash
::
cp_async_wait
<
0
>
();
__syncthreads
();
__syncthreads
();
flash
::
gemm
(
acc_s
,
tSrQ
,
tSrK
,
tSsQ
,
tSsK
,
tiled_mma_sdp
,
smem_thr_copy_QdO
,
smem_thr_copy_KV
);
flash
::
gemm
(
acc_s
,
tSrQ
,
tSrK
,
tSsQ
,
tSsK
,
tiled_mma_sdp
,
smem_tiled_copy_QdO
,
smem_tiled_copy_KV
,
smem_thr_copy_QdO
,
smem_thr_copy_KV
);
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor
scores
=
make_tensor
(
acc_s
.
data
(),
flash
::
convert_layout_acc_rowcol
(
acc_s
.
layout
()));
Tensor
scores
=
make_tensor
(
acc_s
.
data
(),
flash
::
convert_layout_acc_rowcol
(
acc_s
.
layout
()));
...
@@ -1359,7 +1399,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1359,7 +1399,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
// if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
// if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
Tensor
tPrP
=
make_tensor
(
rP
.
data
(),
flash
::
convert_layout_rowcol_Aregs
<
Kernel_traits
::
TiledMmaSdP
>
(
rP
.
layout
()));
Tensor
tPrP
=
make_tensor
(
rP
.
data
(),
flash
::
convert_layout_rowcol_Aregs
<
Kernel_traits
::
TiledMmaSdP
>
(
rP
.
layout
()));
Tensor
tPaP
=
smem_thr_copy_PdS
.
retile_S
(
tPrP
);
// ((Atom,AtomNum), MMA_N, MMA_N)
Tensor
tPaP
=
smem_thr_copy_PdS
.
retile_S
(
tPrP
);
// ((Atom,AtomNum), MMA_N, MMA_N)
copy
(
smem_t
hr
_copy_PdS
,
tPaP
,
tPsP
);
cute
::
copy
(
smem_t
iled
_copy_PdS
,
tPaP
,
tPsP
);
Tensor
acc_dp
=
partition_fragment_C
(
tiled_mma_sdp
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
// (MMA=4, MMA_N, MMA_N)
Tensor
acc_dp
=
partition_fragment_C
(
tiled_mma_sdp
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
// (MMA=4, MMA_N, MMA_N)
CUTE_STATIC_ASSERT_V
(
size
<
0
>
(
acc_dp
)
==
size
<
0
>
(
acc_s
));
// MMA
CUTE_STATIC_ASSERT_V
(
size
<
0
>
(
acc_dp
)
==
size
<
0
>
(
acc_s
));
// MMA
...
@@ -1367,7 +1407,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1367,7 +1407,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
CUTE_STATIC_ASSERT_V
(
size
<
2
>
(
acc_dp
)
==
size
<
2
>
(
acc_s
));
// MMA
CUTE_STATIC_ASSERT_V
(
size
<
2
>
(
acc_dp
)
==
size
<
2
>
(
acc_s
));
// MMA
clear
(
acc_dp
);
clear
(
acc_dp
);
flash
::
gemm
(
acc_dp
,
tdPrdO
,
tdPrV
,
tdPsdO
,
tdPsV
,
tiled_mma_sdp
,
smem_thr_copy_QdO
,
smem_thr_copy_KV
);
flash
::
gemm
(
acc_dp
,
tdPrdO
,
tdPrV
,
tdPsdO
,
tdPsV
,
tiled_mma_sdp
,
smem_tiled_copy_QdO
,
smem_tiled_copy_KV
,
smem_thr_copy_QdO
,
smem_thr_copy_KV
);
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor
dS
=
make_tensor
(
acc_dp
.
data
(),
scores
.
layout
());
Tensor
dS
=
make_tensor
(
acc_dp
.
data
(),
scores
.
layout
());
...
@@ -1386,7 +1427,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1386,7 +1427,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
// Convert dS from fp32 to fp16
// Convert dS from fp32 to fp16
Tensor
tdSrdS
=
flash
::
convert_type
<
Element
>
(
dS_reshaped
);
Tensor
tdSrdS
=
flash
::
convert_type
<
Element
>
(
dS_reshaped
);
Tensor
tdSadS
=
smem_thr_copy_PdS
.
retile_S
(
tdSrdS
);
// ((Atom,AtomNum), MMA_N, MMA_N)
Tensor
tdSadS
=
smem_thr_copy_PdS
.
retile_S
(
tdSrdS
);
// ((Atom,AtomNum), MMA_N, MMA_N)
copy
(
smem_t
hr
_copy_PdS
,
tdSadS
,
tdSsdS
);
cute
::
copy
(
smem_t
iled
_copy_PdS
,
tdSadS
,
tdSsdS
);
__syncthreads
();
__syncthreads
();
if
(
n_block
>
0
)
{
if
(
n_block
>
0
)
{
...
@@ -1397,8 +1438,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1397,8 +1438,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
// Advance gK, gV
// Advance gK, gV
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_t
hr
_copy_QKV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_t
iled
_copy_QKV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_t
hr
_copy_QKV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_t
iled
_copy_QKV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
// isn't right and we get race conditions.
cute
::
cp_async_fence
();
cute
::
cp_async_fence
();
...
@@ -1406,7 +1447,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1406,7 +1447,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
Tensor
acc_dv
=
partition_fragment_C
(
tiled_mma_dkv
,
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_N, MMA_K
Tensor
acc_dv
=
partition_fragment_C
(
tiled_mma_dkv
,
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_N, MMA_K
clear
(
acc_dv
);
clear
(
acc_dv
);
flash
::
gemm
(
acc_dv
,
tdVrPt
,
tdVrdO
,
tdVsPt
,
tdVsdOt
,
tiled_mma_dkv
,
smem_thr_copy_PdSt
,
smem_thr_copy_QdOt
);
flash
::
gemm
(
acc_dv
,
tdVrPt
,
tdVrdO
,
tdVsPt
,
tdVsdOt
,
tiled_mma_dkv
,
smem_tiled_copy_PdSt
,
smem_tiled_copy_QdOt
,
smem_thr_copy_PdSt
,
smem_thr_copy_QdOt
);
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(acc_dv); }
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(acc_dv); }
tdVgdVaccum
.
data
()
=
tdVgdVaccum
.
data
()
+
(
-
int
(
kBlockN
*
params
.
d_rounded
));
tdVgdVaccum
.
data
()
=
tdVgdVaccum
.
data
()
+
(
-
int
(
kBlockN
*
params
.
d_rounded
));
#pragma unroll
#pragma unroll
...
@@ -1415,12 +1457,14 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
...
@@ -1415,12 +1457,14 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
__syncthreads
();
__syncthreads
();
Tensor
acc_dk
=
partition_fragment_C
(
tiled_mma_dkv
,
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_N, MMA_K
Tensor
acc_dk
=
partition_fragment_C
(
tiled_mma_dkv
,
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_N, MMA_K
clear
(
acc_dk
);
clear
(
acc_dk
);
flash
::
gemm
(
acc_dk
,
tdKrdSt
,
tdKrQt
,
tdKsdSt
,
tdKsQt
,
tiled_mma_dkv
,
smem_thr_copy_PdSt
,
smem_thr_copy_QdOt
);
flash
::
gemm
(
acc_dk
,
tdKrdSt
,
tdKrQt
,
tdKsdSt
,
tdKsQt
,
tiled_mma_dkv
,
smem_tiled_copy_PdSt
,
smem_tiled_copy_QdOt
,
smem_thr_copy_PdSt
,
smem_thr_copy_QdOt
);
tdKgdKaccum
.
data
()
=
tdKgdKaccum
.
data
()
+
(
-
int
(
kBlockN
*
params
.
d_rounded
));
tdKgdKaccum
.
data
()
=
tdKgdKaccum
.
data
()
+
(
-
int
(
kBlockN
*
params
.
d_rounded
));
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
size
(
acc_dk
);
++
i
)
{
atomicAdd
(
&
tdKgdKaccum
(
i
),
acc_dk
(
i
));
}
for
(
int
i
=
0
;
i
<
size
(
acc_dk
);
++
i
)
{
atomicAdd
(
&
tdKgdKaccum
(
i
),
acc_dk
(
i
));
}
flash
::
gemm
(
acc_dq
,
tdQrdS
,
tdQrKt
,
tdQsdS
,
tdQsKt
,
tiled_mma_dq
,
smem_thr_copy_dS
,
smem_thr_copy_Kt
);
flash
::
gemm
(
acc_dq
,
tdQrdS
,
tdQrKt
,
tdQsdS
,
tdQsKt
,
tiled_mma_dq
,
smem_tiled_copy_dS
,
smem_tiled_copy_Kt
,
smem_thr_copy_dS
,
smem_thr_copy_Kt
);
// Double buffer for sK
// Double buffer for sK
tdQsKt
.
data
()
=
tdQsKt
.
data
()
+
(
n_block
%
2
==
0
?
size
(
sK
)
:
-
size
(
sK
));
tdQsKt
.
data
()
=
tdQsKt
.
data
()
+
(
n_block
%
2
==
0
?
size
(
sK
)
:
-
size
(
sK
));
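One detail in this hunk that the signature churn can obscure: the dK partials are folded into the global accumulation buffer with one atomicAdd per accumulator element, since several query row blocks can contribute to the same dK tile concurrently. A toy host-side model of that accumulate-into-shared-buffer pattern (plain C++ threads; every name below is illustrative and none of it is the kernel's actual code):

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Toy model of the atomicAdd accumulation above: several "row blocks" (threads
// here stand in for CUDA thread blocks) each hold a partial acc_dk and fold it
// into one shared dKaccum buffer; atomic adds make the arrival order irrelevant.
int main() {
    std::vector<std::atomic<float>> dKaccum(4);
    for (auto& x : dKaccum) { x.store(0.0f); }
    auto block = [&](float partial) {
        for (auto& x : dKaccum) {
            float old = x.load();
            while (!x.compare_exchange_weak(old, old + partial)) {}  // atomicAdd
        }
    };
    std::thread t0(block, 1.0f), t1(block, 2.0f);
    t0.join();
    t1.join();
    std::printf("dKaccum[0] = %.1f\n", dKaccum[0].load());  // prints 3.0
    return 0;
}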
...
@@ -1436,12 +1480,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor sdQ = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutdQ{});
// Partition sdV and sdK to match the accumulator partitioning
auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx);
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
__syncthreads();
copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ);
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
...
@@ -1449,14 +1494,15 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
__syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ);
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
...
@@ -1467,7 +1513,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
);
}
...
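The change repeated throughout this file, and in flash_fwd_kernel.h and utils.h below, is mechanical: a tiled-copy object that used to be a temporary reduced immediately to its per-thread slice is now kept in a named variable, and the slice is derived from it, because the updated helpers take both objects. A minimal standalone C++ sketch of the ownership change (the struct names are stand-ins, not the real CuTe types):

#include <cstdio>

// Stand-ins for CuTe's TiledCopy / ThrCopy pair (illustrative types, not the
// real ones): the tiled object describes the whole copy partitioning, and the
// thread view is a cheap per-thread slice of it.
struct ThrCopyView { int tid; };

struct TiledCopyDesc {
    ThrCopyView get_thread_slice(int tid) const { return ThrCopyView{tid}; }
};

// Mirrors the shape of the updated helper signatures: the tiled object is what
// the copy dispatches on, the thread view is what retile/partition would use.
void use_copy(const TiledCopyDesc& /*tiled*/, const ThrCopyView& thr) {
    std::printf("copy issued for thread %d\n", thr.tid);
}

int main() {
    // Before: auto thr = TiledCopyDesc{}.get_thread_slice(0);  // tiled object lost
    TiledCopyDesc tiled_copy;                               // keep the tiled object
    ThrCopyView thr_copy = tiled_copy.get_thread_slice(0);  // then take the slice
    use_copy(tiled_copy, thr_copy);
    return 0;
}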
csrc/flash_attn/src/flash_fwd_kernel.h
View file @ 3524e13c
...
@@ -77,7 +77,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
flash::reduce_sum(scores, scores_sum);
} else {
Tensor scores_max_prev = make_fragment_like(scores_max);
copy(scores_max, scores_max_prev);
cute::copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
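This hunk touches the heart of the online-softmax update: when a new block of scores raises the running row max, the accumulated output and softmax sum are rescaled by exp(old_max - new_max) before the new block's probabilities are folded in. A scalar toy model of that rescaling, with made-up values and no claim to the kernel's actual layout:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Toy scalar model of softmax_rescale_o: when a new block of scores raises the
// running row max, the existing accumulator and sum are scaled by
// exp(old_max - new_max) so that all contributions stay on one scale.
int main() {
    double scores[2][2] = {{1.0, 2.0}, {5.0, 3.0}};  // two score blocks, one row
    double m = -INFINITY, s = 0.0, acc = 0.0;        // running max / sum / output
    for (auto& blk : scores) {
        double m_new = std::max({m, blk[0], blk[1]});
        double rescale = std::exp(m - m_new);        // exp(-inf) == 0 on first pass
        s *= rescale;
        acc *= rescale;
        for (double x : blk) {
            double p = std::exp(x - m_new);
            s += p;
            acc += p * 1.0;                          // the "V" values are all 1 here
        }
        m = m_new;
    }
    std::printf("softmax-weighted sum = %f\n", acc / s);  // prints 1.000000
    return 0;
}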
...
@@ -103,7 +103,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
...
@@ -112,7 +112,7 @@ inline __device__ void write_softmax_to_gmem(
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
...
@@ -186,8 +186,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
...
@@ -209,16 +211,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Copy Atom retiling
//
auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70
...
@@ -269,7 +274,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
...
@@ -286,13 +291,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
...
@@ -303,7 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}
auto seeds = at::cuda::philox::unpack(params.philox_args);
...
@@ -335,17 +340,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
...
@@ -382,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
...
@@ -402,12 +408,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
...
@@ -416,7 +422,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
// if (cute::thread0()) { print(tOrP); }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
...
@@ -434,11 +440,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K
);
flash::cp_async_wait<0>();
...
@@ -446,7 +453,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
...
@@ -464,12 +471,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
...
@@ -477,7 +484,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
block_row_idx, block_col_idx, kNWarps);
}
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
// Epilogue
...
@@ -501,7 +508,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO);  // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
...
@@ -509,7 +517,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
copy(smem_thr_copy_O, taccOrO, taccOsO);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
...
@@ -520,14 +528,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
__syncthreads();
Tensor tOrO = make_tensor<Element>(shape(tOgO));
copy(gmem_thr_copy_O, tOsO, tOrO);
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
...
@@ -554,7 +563,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
}
...
csrc/flash_attn/src/utils.h
View file @ 3524e13c
...
@@ -173,10 +173,12 @@ static __device__ inline T run(T x, Operator &op) {
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy0, typename TiledCopy1>
typename TiledMma, typename TiledCopyA, typename TiledCopyB, typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));  // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));  // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
...
@@ -184,13 +186,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
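Beyond the signature change, the body of gemm above shows the k-slice software pipeline: slice i+1 is staged from shared memory into registers before the MMA on slice i is issued, so copy latency hides behind the math. A toy scalar model of that prologue-plus-prefetch loop (plain C++, made-up values):

#include <cstdio>
#include <vector>

// Toy model of the k-slice pipelining in flash::gemm above: slice i+1 is staged
// (the smem-to-register copy) before the MMA on slice i is issued, so the copy
// of the next slice overlaps the math on the current one.
int main() {
    const int kSlices = 4;
    std::vector<int> smem = {10, 20, 30, 40};            // stand-in for tCsA/tCsB slices
    int reg = smem[0];                                   // prologue: copy slice 0
    int acc = 0;
    for (int i = 0; i < kSlices; ++i) {
        int next = (i < kSlices - 1) ? smem[i + 1] : 0;  // prefetch slice i+1
        acc += reg;                                      // the "cute::gemm" on slice i
        reg = next;                                      // registers now hold slice i+1
    }
    std::printf("acc = %d\n", acc);                      // prints acc = 100
    return 0;
}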
...
@@ -199,19 +201,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy>
typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));  // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));  // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
...
@@ -319,7 +322,7 @@ void cp_async_wait() {
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
...
@@ -335,13 +338,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
copy(thr_copy, S(_, m, k), D(_, m, k));
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
clear(D(_, m, k));
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
clear(D(_, m, _));
cute::clear(D(_, m, _));
}
}
// TD [2023-04-13]: Strange that the code below can cause race condition.
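The function being edited here is the predicated copy: rows are copied only while in bounds (or zero-filled when Clear_OOB_MN is set), and within a row each k chunk is gated by predicate_K, with Clear_OOB_K choosing between zero-filling and skipping. That is why the callers above that write straight to gmem pass Clear_OOB_K=false. A toy host-side model of the predication (fixed 4x4 tiles, illustrative only):

#include <array>
#include <cstdio>

// Toy model of flash::copy's predication: rows past max_MN are skipped or
// zero-filled (Clear_OOB_MN); columns are gated by predicate_K, and
// Clear_OOB_K picks between zero-filling and leaving D untouched.
template <bool Clear_OOB_MN, bool Clear_OOB_K>
void predicated_copy(const std::array<std::array<int, 4>, 4>& S,
                     std::array<std::array<int, 4>, 4>& D,
                     const std::array<bool, 4>& predicate_K, int max_MN) {
    for (int m = 0; m < 4; ++m) {
        if (m < max_MN) {
            for (int k = 0; k < 4; ++k) {
                if (predicate_K[k])   { D[m][k] = S[m][k]; }  // in-bounds copy
                else if (Clear_OOB_K) { D[m][k] = 0; }        // zero OOB column
            }
        } else if (Clear_OOB_MN) {
            D[m].fill(0);                                     // zero OOB row
        }
    }
}

int main() {
    std::array<std::array<int, 4>, 4> S{}, D{};
    for (int m = 0; m < 4; ++m)
        for (int k = 0; k < 4; ++k) { S[m][k] = 10 * m + k; D[m][k] = -1; }
    predicated_copy</*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/false>(
        S, D, {true, true, true, false}, /*max_MN=*/3);
    std::printf("D[0][3] = %d (untouched), D[3][0] = %d (cleared)\n",
                D[0][3], D[3][0]);  // prints -1 and 0
    return 0;
}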
...
@@ -350,7 +353,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
//     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
//         copy(thr_copy, S(_, m, _), D(_, m, _));
//         copy(tiled_copy, S(_, m, _), D(_, m, _));
//     } else if (Clear_OOB_MN) {
//         clear(D(_, m, _));
//     }
...
@@ -362,7 +365,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
//     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
//         copy(thr_copy, S(_, m, k), D(_, m, k));
//         copy(tiled_copy, S(_, m, k), D(_, m, k));
//     } else if (Clear_OOB_MN) {
//         clear(D(_, m, k));
//     }
...
tests/test_flash_attn.py
View file @ 3524e13c
...
@@ -783,13 +783,13 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
@pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize('seqlen', [128])
...