Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7aa37568
Commit
7aa37568
authored
Aug 11, 2023
by
danyao12
Browse files
qloop dropout optimize
parent
4274096b
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
435 additions
and
380 deletions
+435
-380
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
...e_softmax_gemm/batched_multihead_attention_forward_v2.cpp
+6
-3
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v2.cpp
+15
-15
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
...e_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
+6
-3
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+15
-15
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+7
-9
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+15
-12
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+9
-11
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...u/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
...u/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+34
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+34
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+34
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+34
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v1.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v1.hpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+211
-169
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
View file @
7aa37568
...
...
@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
...
...
@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8
,
true
,
4
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
View file @
7aa37568
This diff is collapsed.
Click to expand it.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
View file @
7aa37568
...
...
@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
...
...
@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8
,
true
,
1
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
7aa37568
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
7aa37568
...
...
@@ -138,12 +138,12 @@ struct BlockwiseDropout
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_
4
x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
Offset
{}
*
MRaw
);
ph
.
get_random_
8
x16
((
tmp
+
i
*
8
),
element_global_1d_id
+
i
*
Offset
{}
*
MRaw
);
}
block_sync_lds
();
...
...
@@ -179,12 +179,12 @@ struct BlockwiseDropout
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_
4
x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
Offset
{}
*
MRaw
);
ph
.
get_random_
8
x16
((
tmp
+
i
*
8
),
element_global_1d_id
+
i
*
Offset
{}
*
MRaw
);
}
block_sync_lds
();
...
...
@@ -218,21 +218,19 @@ struct BlockwiseDropout
}
// get raw z matrix with random number for shuffle
template
<
typename
ZThreadBuffer
,
typename
Step
,
typename
Offset
>
// N3*N4=8
template
<
typename
ZThreadBuffer
,
typename
Step
,
typename
Offset
>
__host__
__device__
void
GenerateZMatrixAttnFwd
(
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
{
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
/
Step
{}.
value
;
int
philox_calls
=
tmp_size
/
4
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_
4
x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
Offset
{});
ph
.
get_random_
8
x16
((
tmp
+
i
*
8
),
element_global_1d_id
+
i
*
Offset
{});
}
static_for
<
0
,
tmp_size
,
1
>
{}([
&
](
auto
i
)
{
z_thread_buf
(
i
)
=
tmp
[
i
.
value
];
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
7aa37568
...
...
@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
typename
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
,
typename
LSEGridDescriptor_M
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
...
...
@@ -73,8 +73,8 @@ __global__ void
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
...
...
@@ -141,7 +141,7 @@ __global__ void
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
lse_grid_desc_m
,
block_2_ctile_map
,
c0_matrix_mask
,
...
...
@@ -174,7 +174,7 @@ __global__ void
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
lse_grid_desc_m
,
block_2_ctile_map
,
c0_matrix_mask
,
...
...
@@ -203,7 +203,7 @@ __global__ void
ignore
=
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
ignore
=
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
;
ignore
=
lse_grid_desc_m
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
...
...
@@ -263,6 +263,7 @@ template <index_t NumDimG,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
DropoutStep
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -564,6 +565,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
DropoutStep
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -735,8 +737,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
(
z_grid_desc_m_n_
);
m_raw_padded_
=
GridwiseGemm
::
GetPaddedSize
(
raw_lengths_mz_nz_kz_gemm1nz_
[
0
]);
n_raw_padded_
=
GridwiseGemm
::
GetPaddedSize
(
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
...
...
@@ -791,8 +794,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
;
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
...
@@ -876,7 +879,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
,
DeviceOp
::
LSEGridDesc_M
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
...
...
@@ -909,7 +912,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg
.
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
,
arg
.
lse_grid_desc_m_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
7aa37568
...
...
@@ -135,7 +135,7 @@ __global__ void
arg_ptr
[
group_id
].
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
...
@@ -173,7 +173,7 @@ __global__ void
arg_ptr
[
group_id
].
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
...
@@ -244,6 +244,7 @@ template <index_t NumDimG,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
DropoutStep
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -566,6 +567,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
DropoutStep
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -622,8 +624,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
LSEGridDesc_M
lse_grid_desc_m_
;
...
...
@@ -768,12 +770,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
// typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const
auto
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
auto
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
(
z_grid_desc_m_n
);
const
index_t
BlockStart
=
grid_size_
;
...
...
@@ -829,7 +827,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
z_grid_desc_m_n
,
lse_grid_desc_m
,
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
7aa37568
...
...
@@ -1533,8 +1533,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
...
...
@@ -1966,16 +1966,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_ten
s
or_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
s_slash_p_thread_buf
,
ph
,
z_ten
s
or_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
View file @
7aa37568
...
...
@@ -1473,8 +1473,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
...
...
@@ -1865,16 +1865,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_ten
s
or_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
s_slash_p_thread_buf
,
ph
,
z_ten
s
or_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
7aa37568
...
...
@@ -110,6 +110,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// C desc for source in blockwise copy
...
...
@@ -119,10 +124,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
...
...
@@ -136,9 +140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -542,9 +544,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
BBlockDesc_BK0_N_BK1
{});
}
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
...
...
@@ -646,8 +646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmKPack
=
mfma
.
group_size
;
static
constexpr
index_t
GemmMWave
=
Gemm0NWaves
;
// 4 // 4
static
constexpr
index_t
GemmNWave
=
Gemm0MWaves
;
// 1 // 1
...
...
@@ -770,9 +769,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
// 1 // 1
static
constexpr
index_t
GemmMRepeat
=
Gemm2_M
/
GemmMWave
/
MPerXdl
;
// 1 // 1
static
constexpr
index_t
GemmKLoop
=
Gemm2_K
/
Sum_K
;
// 2 // 2
static
constexpr
index_t
GemmKPack
=
math
::
max
(
A_K1
,
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 8
static
constexpr
index_t
GemmKPack
=
math
::
max
(
A_K1
,
mfma
.
k_per_blk
);
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 8
static
constexpr
index_t
B_K2
=
XdlopsGemm
<
GemmDataType
,
MPerXdl
,
NPerXdl
,
GemmKPack
,
false
>
{}.
K0PerXdlops
;
// 2
static
constexpr
index_t
B_K1
=
Sum_K
/
B_K2
/
B_K3
;
// 4
...
...
@@ -1570,8 +1568,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ushort
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
...
...
@@ -1759,7 +1757,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
)
{
...
...
@@ -1774,23 +1771,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tensor_buffer
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tensor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
}
...
...
@@ -1806,15 +1807,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
7aa37568
...
...
@@ -121,6 +121,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
...
...
@@ -133,10 +138,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
...
...
@@ -150,9 +154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -522,9 +524,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
...
...
@@ -657,8 +657,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmKPack
=
mfma
.
group_size
;
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
...
...
@@ -709,9 +708,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
index_t
GemmMWave
=
BlockSize
/
get_warp_size
()
/
GemmNWave
;
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
static
constexpr
index_t
GemmMRepeat
=
Gemm2_M
/
GemmMWave
/
MPerXdl
;
static
constexpr
index_t
GemmKPack
=
math
::
max
(
math
::
lcm
(
A_K1
,
B_K1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
GemmKPack
=
math
::
max
(
math
::
lcm
(
A_K1
,
B_K1
),
mfma
.
k_per_blk
);
using
BBlockSliceLengths
=
Sequence
<
B_K0
,
Gemm2_N
,
B_K1
>
;
using
BThreadClusterLengths
=
...
...
@@ -1554,8 +1551,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ushort
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
...
...
@@ -1722,7 +1719,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
)
{
...
...
@@ -1737,23 +1733,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tensor_buffer
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tensor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
}
...
...
@@ -1769,14 +1769,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
7aa37568
...
...
@@ -109,6 +109,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// C desc for source in blockwise copy
...
...
@@ -118,10 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
...
...
@@ -135,9 +139,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -563,9 +565,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
BBlockDesc_BK0_N_BK1
{});
}
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
...
...
@@ -667,8 +667,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmKPack
=
mfma
.
group_size
;
static
constexpr
index_t
GemmMWave
=
Gemm0NWaves
;
// 4 // 4
static
constexpr
index_t
GemmNWave
=
Gemm0MWaves
;
// 1 // 1
...
...
@@ -791,9 +790,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
// 1 // 1
static
constexpr
index_t
GemmMRepeat
=
Gemm2_M
/
GemmMWave
/
MPerXdl
;
// 1 // 1
static
constexpr
index_t
GemmKLoop
=
Gemm2_K
/
Sum_K
;
// 2 // 2
static
constexpr
index_t
GemmKPack
=
math
::
max
(
A_K1
,
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 8
static
constexpr
index_t
GemmKPack
=
math
::
max
(
A_K1
,
mfma
.
k_per_blk
);
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 8
static
constexpr
index_t
B_K2
=
XdlopsGemm
<
GemmDataType
,
MPerXdl
,
NPerXdl
,
GemmKPack
,
false
>
{}.
K0PerXdlops
;
// 2
static
constexpr
index_t
B_K1
=
Sum_K
/
B_K2
/
B_K3
;
// 4
...
...
@@ -1621,8 +1619,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ushort
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
...
...
@@ -1946,7 +1944,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
)
{
...
...
@@ -1961,23 +1958,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tensor_buffer
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tensor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
}
...
...
@@ -1993,15 +1994,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
7aa37568
...
...
@@ -120,6 +120,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
...
...
@@ -132,10 +137,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
...
...
@@ -149,9 +153,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -543,9 +545,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
...
...
@@ -678,8 +678,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmKPack
=
mfma
.
group_size
;
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
...
...
@@ -730,9 +729,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
index_t
GemmMWave
=
BlockSize
/
get_warp_size
()
/
GemmNWave
;
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
static
constexpr
index_t
GemmMRepeat
=
Gemm2_M
/
GemmMWave
/
MPerXdl
;
static
constexpr
index_t
GemmKPack
=
math
::
max
(
math
::
lcm
(
A_K1
,
B_K1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
GemmKPack
=
math
::
max
(
math
::
lcm
(
A_K1
,
B_K1
),
mfma
.
k_per_blk
);
using
BBlockSliceLengths
=
Sequence
<
B_K0
,
Gemm2_N
,
B_K1
>
;
using
BThreadClusterLengths
=
...
...
@@ -1582,8 +1579,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ushort
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
...
...
@@ -1862,7 +1859,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
)
{
...
...
@@ -1877,23 +1873,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tensor_buffer
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tensor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
}
...
...
@@ -1909,14 +1909,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v1.hpp
View file @
7aa37568
...
...
@@ -873,8 +873,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
// z matrix global desc
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
@@ -1022,16 +1022,16 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_ten
s
or_buffer
),
false
,
decltype
(
n0
),
decltype
(
i
)>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
acc_thread_buf
,
ph
,
z_ten
s
or_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
7aa37568
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment