Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
148fa857
Commit
148fa857
authored
Sep 11, 2023
by
danyao12
Browse files
Merge branch 'mha-train-develop' into mha-train-develop-dropout8bit
parents
af695bee
21ef37b4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
58 deletions
+50
-58
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+6
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+19
-22
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+6
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+19
-22
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
148fa857
...
@@ -544,16 +544,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -544,16 +544,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
};
};
// dP Gemm (type 1 rcc)
// dP Gemm (type 1 rcc)
template
<
typename
BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
struct
Gemm0
struct
Gemm0
{
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B source matrix layout in VGPR
static
constexpr
auto
b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2
=
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2
();
template
<
typename
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
template
<
typename
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
const
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
&
b_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
const
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
&
b_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
...
@@ -580,7 +577,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -580,7 +577,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
}
}
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
GetBThreadDescriptor_K0_N_K1
(
b_src_t
hread
_d
esc_
k
0_
k
1_
n
0_
n
1_
n
2_
n
3_
k2
);
GetBThreadDescriptor_K0_N_K1
(
BSrcT
hread
D
esc_
K
0_
K
1_
N
0_
N
1_
N
2_
N
3_
K2
{}
);
template
<
typename
ABlockDesc_AK0_M_AK1
>
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
...
@@ -1296,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1296,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
true
,
// DstResetCoord
true
,
// DstResetCoord
1
>
;
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
@@ -1523,6 +1520,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1523,6 +1520,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// set up dP Gemm (type 1 rcc)
// set up dP Gemm (type 1 rcc)
//
//
using
Gemm0
=
Gemm0
<
decltype
(
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
>
;
// dP: blockwise gemm
// dP: blockwise gemm
auto
pgrad_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
auto
pgrad_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
pgrad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
pgrad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
...
@@ -1857,7 +1856,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1857,7 +1856,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0Thread
Wise
Copy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
148fa857
...
@@ -547,16 +547,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -547,16 +547,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
};
// dP Gemm (type 1 rcc, B in Vgpr)
// dP Gemm (type 1 rcc, B in Vgpr)
template
<
typename
BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
struct
Gemm0
struct
Gemm0
{
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B source matrix layout in VGPR
static
constexpr
auto
b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
=
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
();
template
<
typename
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
template
<
typename
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
const
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
&
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
const
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
&
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
...
@@ -584,7 +581,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -584,7 +581,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}
}
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
GetBThreadDescriptor_K0_N_K1
(
b_src_t
hread
_d
esc_
k
0_
k
1_
k
2_
n
0_
n
1_
n
2_
n
3_
k3
);
GetBThreadDescriptor_K0_N_K1
(
BSrcT
hread
D
esc_
K
0_
K
1_
K
2_
N
0_
N
1_
N
2_
N
3_
K3
{}
);
template
<
typename
ABlockDesc_AK0_M_AK1
>
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
...
@@ -847,7 +844,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -847,7 +844,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
};
// dQ Gemm (type 3 crr)
// dQ Gemm (type 3 crr)
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
>
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
,
typename
BSrcBlockDesc_N0_K_N1
>
struct
Gemm2
struct
Gemm2
{
{
private:
private:
...
@@ -870,9 +867,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -870,9 +867,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_k0_m_k1
=
GetA2BlockDescriptor_K0_M_K1
<
Gemm2Params
>
();
static
constexpr
auto
a_block_desc_k0_m_k1
=
GetA2BlockDescriptor_K0_M_K1
<
Gemm2Params
>
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_n0_k_n1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
template
<
typename
ABlockDesc_K0_M_K1
>
template
<
typename
ABlockDesc_K0_M_K1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_K0_M_K1
&
)
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_K0_M_K1
&
)
...
@@ -969,12 +963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -969,12 +963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
()
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
()
{
{
const
auto
N0_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I0
);
const
auto
N0_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I0
);
const
auto
K_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I1
);
const
auto
K_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I1
);
const
auto
N1_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I2
);
const
auto
N1_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I2
);
constexpr
auto
b_block_desc_n_k
=
transform_tensor_descriptor
(
//(32, 128) //(64, 128)
constexpr
auto
b_block_desc_n_k
=
transform_tensor_descriptor
(
//(32, 128) //(64, 128)
b_b
lock
_d
esc_
n
0_
k_n1
,
BSrcB
lock
D
esc_
N
0_
K_N1
{}
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
N0_
,
N1_
)),
//(4, 8) //(8, 8)
make_tuple
(
N0_
,
N1_
)),
//(4, 8) //(8, 8)
make_pass_through_transform
(
K_
)),
// 128 // 128
make_pass_through_transform
(
K_
)),
// 128 // 128
...
@@ -1120,16 +1114,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1120,16 +1114,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
};
// S Gemm (type 4 rcc, B in LDS)
// S Gemm (type 4 rcc, B in LDS)
template
<
typename
BSrcBlockDesc_K0_N_K1
>
struct
Gemm3
struct
Gemm3
{
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
template
<
typename
ABlockDesc_AK0_M_AK1
>
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
...
@@ -1183,9 +1174,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1183,9 +1174,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
GemmDataType
,
GemmDataType
,
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_b
lock
_d
esc_
bk0_n_bk1
)
,
BSrcB
lock
D
esc_
K0_N_K1
,
decltype
(
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K
(
b_b
lock
_d
esc_
bk0_n_bk1
)),
decltype
(
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K
(
BSrcB
lock
D
esc_
K0_N_K1
{}
)),
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
...
@@ -1381,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1381,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true
,
// DstResetCoord
true
,
// DstResetCoord
1
>
;
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
@@ -1594,6 +1585,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1594,6 +1585,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up dP Gemm (type 1 rcc)
// set up dP Gemm (type 1 rcc)
//
//
using
Gemm0
=
Gemm0
<
decltype
(
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
>
;
// Gemm0: LDS allocation for A and B: be careful of alignment
// Gemm0: LDS allocation for A and B: be careful of alignment
auto
gemm0_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
gemm0_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
...
@@ -1630,6 +1623,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1630,6 +1623,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up S Gemm (type 4 rcc)
// set up S Gemm (type 4 rcc)
//
//
using
Gemm3
=
Gemm3
<
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm3: LDS allocation for A and B: be careful of alignment
// Gemm3: LDS allocation for A and B: be careful of alignment
auto
gemm3_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
gemm3_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
...
@@ -1735,7 +1730,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1735,7 +1730,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
//
//
// set up dQ Gemm (type 3 crr)
// set up dQ Gemm (type 3 crr)
//
//
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
)
>
;
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
),
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm2: LDS allocation for A and B: be careful of alignment
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
...
@@ -1980,7 +1977,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1980,7 +1977,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0Thread
Wise
Copy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
148fa857
...
@@ -565,16 +565,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -565,16 +565,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
};
};
// dP Gemm (type 1 rcc)
// dP Gemm (type 1 rcc)
template
<
typename
BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
struct
Gemm0
struct
Gemm0
{
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B source matrix layout in VGPR
static
constexpr
auto
b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2
=
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2
();
template
<
typename
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
template
<
typename
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
const
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
&
b_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
const
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
&
b_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
...
@@ -601,7 +598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -601,7 +598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
}
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
GetBThreadDescriptor_K0_N_K1
(
b_src_t
hread
_d
esc_
k
0_
k
1_
n
0_
n
1_
n
2_
n
3_
k2
);
GetBThreadDescriptor_K0_N_K1
(
BSrcT
hread
D
esc_
K
0_
K
1_
N
0_
N
1_
N
2_
N
3_
K2
{}
);
template
<
typename
ABlockDesc_AK0_M_AK1
>
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
...
@@ -1364,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1364,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
true
,
// DstResetCoord
true
,
// DstResetCoord
1
>
;
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
@@ -1606,6 +1603,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1606,6 +1603,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// set up dP Gemm (type 1 rcc)
// set up dP Gemm (type 1 rcc)
//
//
using
Gemm0
=
Gemm0
<
decltype
(
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
>
;
// dP: blockwise gemm
// dP: blockwise gemm
auto
pgrad_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
auto
pgrad_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
pgrad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
pgrad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
...
@@ -2018,7 +2017,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -2018,7 +2017,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0Thread
Wise
Copy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
148fa857
...
@@ -568,16 +568,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -568,16 +568,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
};
// dP Gemm (type 1 rcc, B in Vgpr)
// dP Gemm (type 1 rcc, B in Vgpr)
template
<
typename
BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
struct
Gemm0
struct
Gemm0
{
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B source matrix layout in VGPR
static
constexpr
auto
b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
=
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
();
template
<
typename
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
template
<
typename
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
const
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
&
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
const
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
&
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
...
@@ -605,7 +602,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -605,7 +602,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
}
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
GetBThreadDescriptor_K0_N_K1
(
b_src_t
hread
_d
esc_
k
0_
k
1_
k
2_
n
0_
n
1_
n
2_
n
3_
k3
);
GetBThreadDescriptor_K0_N_K1
(
BSrcT
hread
D
esc_
K
0_
K
1_
K
2_
N
0_
N
1_
N
2_
N
3_
K3
{}
);
template
<
typename
ABlockDesc_AK0_M_AK1
>
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
...
@@ -868,7 +865,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -868,7 +865,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
};
// dQ Gemm (type 3 crr)
// dQ Gemm (type 3 crr)
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
>
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
,
typename
BSrcBlockDesc_N0_K_N1
>
struct
Gemm2
struct
Gemm2
{
{
private:
private:
...
@@ -891,9 +888,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -891,9 +888,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_k0_m_k1
=
GetA2BlockDescriptor_K0_M_K1
<
Gemm2Params
>
();
static
constexpr
auto
a_block_desc_k0_m_k1
=
GetA2BlockDescriptor_K0_M_K1
<
Gemm2Params
>
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_n0_k_n1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
template
<
typename
ABlockDesc_K0_M_K1
>
template
<
typename
ABlockDesc_K0_M_K1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_K0_M_K1
&
)
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_K0_M_K1
&
)
...
@@ -990,12 +984,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -990,12 +984,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
()
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
()
{
{
const
auto
N0_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I0
);
const
auto
N0_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I0
);
const
auto
K_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I1
);
const
auto
K_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I1
);
const
auto
N1_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I2
);
const
auto
N1_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I2
);
constexpr
auto
b_block_desc_n_k
=
transform_tensor_descriptor
(
//(32, 128) //(64, 128)
constexpr
auto
b_block_desc_n_k
=
transform_tensor_descriptor
(
//(32, 128) //(64, 128)
b_b
lock
_d
esc_
n
0_
k_n1
,
BSrcB
lock
D
esc_
N
0_
K_N1
{}
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
N0_
,
N1_
)),
//(4, 8) //(8, 8)
make_tuple
(
N0_
,
N1_
)),
//(4, 8) //(8, 8)
make_pass_through_transform
(
K_
)),
// 128 // 128
make_pass_through_transform
(
K_
)),
// 128 // 128
...
@@ -1141,16 +1135,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1141,16 +1135,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
};
// S Gemm (type 4 rcc, B in LDS)
// S Gemm (type 4 rcc, B in LDS)
template
<
typename
BSrcBlockDesc_K0_N_K1
>
struct
Gemm3
struct
Gemm3
{
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
template
<
typename
ABlockDesc_AK0_M_AK1
>
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
...
@@ -1204,9 +1195,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1204,9 +1195,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GemmDataType
,
GemmDataType
,
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_b
lock
_d
esc_
bk0_n_bk1
)
,
BSrcB
lock
D
esc_
K0_N_K1
,
decltype
(
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K
(
b_b
lock
_d
esc_
bk0_n_bk1
)),
decltype
(
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K
(
BSrcB
lock
D
esc_
K0_N_K1
{}
)),
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
...
@@ -1426,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1426,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true
,
// DstResetCoord
true
,
// DstResetCoord
1
>
;
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
@@ -1652,6 +1643,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1652,6 +1643,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up dP Gemm (type 1 rcc)
// set up dP Gemm (type 1 rcc)
//
//
using
Gemm0
=
Gemm0
<
decltype
(
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
>
;
// Gemm0: LDS allocation for A and B: be careful of alignment
// Gemm0: LDS allocation for A and B: be careful of alignment
auto
gemm0_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
gemm0_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
...
@@ -1688,6 +1681,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1688,6 +1681,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up S Gemm (type 4 rcc)
// set up S Gemm (type 4 rcc)
//
//
using
Gemm3
=
Gemm3
<
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm3: LDS allocation for A and B: be careful of alignment
// Gemm3: LDS allocation for A and B: be careful of alignment
auto
gemm3_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
gemm3_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
...
@@ -1793,7 +1788,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1793,7 +1788,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
//
// set up dQ Gemm (type 3 crr)
// set up dQ Gemm (type 3 crr)
//
//
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
)
>
;
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
),
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm2: LDS allocation for A and B: be careful of alignment
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
...
@@ -2093,7 +2090,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -2093,7 +2090,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0Thread
Wise
Copy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment