Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
63e3f3c4
Commit
63e3f3c4
authored
Sep 11, 2023
by
letaoqin
Browse files
Merge branch 'mha-train-develop-bwdopt-bias' into mha-train-develop-grad-bias
parents
592b0649
db579ac9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
87 additions
and
96 deletions
+87
-96
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+13
-14
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+31
-34
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+12
-13
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+31
-35
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
63e3f3c4
...
...
@@ -265,26 +265,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
@@ -544,16 +544,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
};
// dP Gemm (type 1 rcc)
template
<
typename
BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
struct
Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B source matrix layout in VGPR
static
constexpr
auto
b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2
=
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2
();
template
<
typename
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
const
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
&
b_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
...
...
@@ -580,7 +577,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
}
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
GetBThreadDescriptor_K0_N_K1
(
b_src_t
hread
_d
esc_
k
0_
k
1_
n
0_
n
1_
n
2_
n
3_
k2
);
GetBThreadDescriptor_K0_N_K1
(
BSrcT
hread
D
esc_
K
0_
K
1_
N
0_
N
1_
N
2_
N
3_
K2
{}
);
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
...
...
@@ -1296,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
true
,
// DstResetCoord
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
...
@@ -1523,6 +1520,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// set up dP Gemm (type 1 rcc)
//
using
Gemm0
=
Gemm0
<
decltype
(
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
>
;
// dP: blockwise gemm
auto
pgrad_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
pgrad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
...
...
@@ -1857,7 +1856,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0Thread
Wise
Copy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
if
constexpr
(
Deterministic
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
63e3f3c4
...
...
@@ -115,7 +115,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
...
...
@@ -127,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
V_K3
=
BK1
;
static
constexpr
auto
V_K2
=
mfma
.
num_input_blks
;
static
constexpr
auto
V_K1
=
KPerBlock
/
V_K2
/
V_K3
;
...
...
@@ -310,26 +310,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
@@ -547,16 +547,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// dP Gemm (type 1 rcc, B in Vgpr)
template
<
typename
BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
struct
Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B source matrix layout in VGPR
static
constexpr
auto
b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
=
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
();
template
<
typename
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
const
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
&
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
...
...
@@ -584,7 +581,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
GetBThreadDescriptor_K0_N_K1
(
b_src_t
hread
_d
esc_
k
0_
k
1_
k
2_
n
0_
n
1_
n
2_
n
3_
k3
);
GetBThreadDescriptor_K0_N_K1
(
BSrcT
hread
D
esc_
K
0_
K
1_
K
2_
N
0_
N
1_
N
2_
N
3_
K3
{}
);
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
...
...
@@ -847,7 +844,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// dQ Gemm (type 3 crr)
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
>
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
,
typename
BSrcBlockDesc_N0_K_N1
>
struct
Gemm2
{
private:
...
...
@@ -870,9 +867,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_k0_m_k1
=
GetA2BlockDescriptor_K0_M_K1
<
Gemm2Params
>
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_n0_k_n1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
template
<
typename
ABlockDesc_K0_M_K1
>
__host__
__device__
static
constexpr
auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_K0_M_K1
&
)
...
...
@@ -969,12 +963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
()
{
const
auto
N0_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I0
);
const
auto
K_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I1
);
const
auto
N1_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I2
);
const
auto
N0_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I0
);
const
auto
K_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I1
);
const
auto
N1_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I2
);
constexpr
auto
b_block_desc_n_k
=
transform_tensor_descriptor
(
//(32, 128) //(64, 128)
b_b
lock
_d
esc_
n
0_
k_n1
,
BSrcB
lock
D
esc_
N
0_
K_N1
{}
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
N0_
,
N1_
)),
//(4, 8) //(8, 8)
make_pass_through_transform
(
K_
)),
// 128 // 128
...
...
@@ -1120,16 +1114,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// S Gemm (type 4 rcc, B in LDS)
template
<
typename
BSrcBlockDesc_K0_N_K1
>
struct
Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
...
...
@@ -1183,9 +1174,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
GemmDataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_b
lock
_d
esc_
bk0_n_bk1
)
,
BSrcB
lock
D
esc_
K0_N_K1
,
decltype
(
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K
(
b_b
lock
_d
esc_
bk0_n_bk1
)),
decltype
(
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K
(
BSrcB
lock
D
esc_
K0_N_K1
{}
)),
MPerBlock
,
NPerBlock
,
KPerBlock
,
...
...
@@ -1381,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true
,
// DstResetCoord
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
...
@@ -1594,6 +1585,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up dP Gemm (type 1 rcc)
//
using
Gemm0
=
Gemm0
<
decltype
(
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
>
;
// Gemm0: LDS allocation for A and B: be careful of alignment
auto
gemm0_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
...
...
@@ -1630,6 +1623,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up S Gemm (type 4 rcc)
//
using
Gemm3
=
Gemm3
<
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm3: LDS allocation for A and B: be careful of alignment
auto
gemm3_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
...
...
@@ -1735,7 +1730,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
//
// set up dQ Gemm (type 3 crr)
//
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
)
>
;
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
),
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
...
...
@@ -1980,7 +1977,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0Thread
Wise
Copy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
if
constexpr
(
Deterministic
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
63e3f3c4
...
...
@@ -264,26 +264,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
@@ -565,16 +565,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
};
// dP Gemm (type 1 rcc)
template
<
typename
BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
struct
Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B source matrix layout in VGPR
static
constexpr
auto
b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2
=
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2
();
template
<
typename
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
>
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
const
BThreadDesc_K0_K1_N0_N1_N2_N3_K2
&
b_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
...
...
@@ -601,7 +598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
GetBThreadDescriptor_K0_N_K1
(
b_src_t
hread
_d
esc_
k
0_
k
1_
n
0_
n
1_
n
2_
n
3_
k2
);
GetBThreadDescriptor_K0_N_K1
(
BSrcT
hread
D
esc_
K
0_
K
1_
N
0_
N
1_
N
2_
N
3_
K2
{}
);
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
...
...
@@ -1364,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
true
,
// DstResetCoord
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_vgpr_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
...
@@ -1651,6 +1648,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// set up dP Gemm (type 1 rcc)
//
using
Gemm0
=
Gemm0
<
decltype
(
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_n0_n1_n2_n3_k2
)
>
;
// dP: blockwise gemm
auto
pgrad_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
pgrad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
63e3f3c4
...
...
@@ -114,7 +114,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
...
...
@@ -126,6 +125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
V_K3
=
BK1
;
static
constexpr
auto
V_K2
=
mfma
.
num_input_blks
;
static
constexpr
auto
V_K1
=
KPerBlock
/
V_K2
/
V_K3
;
...
...
@@ -309,26 +309,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
@@ -568,16 +568,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// dP Gemm (type 1 rcc, B in Vgpr)
template
<
typename
BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
struct
Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B source matrix layout in VGPR
static
constexpr
auto
b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
=
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
();
template
<
typename
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
__host__
__device__
static
constexpr
auto
GetBThreadDescriptor_K0_N_K1
(
const
BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3
&
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
...
...
@@ -605,7 +602,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
static
constexpr
auto
b_src_thread_desc_k0_n_k1
=
GetBThreadDescriptor_K0_N_K1
(
b_src_t
hread
_d
esc_
k
0_
k
1_
k
2_
n
0_
n
1_
n
2_
n
3_
k3
);
GetBThreadDescriptor_K0_N_K1
(
BSrcT
hread
D
esc_
K
0_
K
1_
K
2_
N
0_
N
1_
N
2_
N
3_
K3
{}
);
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
...
...
@@ -868,7 +865,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// dQ Gemm (type 3 crr)
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
>
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
,
typename
BSrcBlockDesc_N0_K_N1
>
struct
Gemm2
{
private:
...
...
@@ -891,9 +888,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_k0_m_k1
=
GetA2BlockDescriptor_K0_M_K1
<
Gemm2Params
>
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_n0_k_n1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
template
<
typename
ABlockDesc_K0_M_K1
>
__host__
__device__
static
constexpr
auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_K0_M_K1
&
)
...
...
@@ -990,12 +984,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
()
{
const
auto
N0_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I0
);
const
auto
K_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I1
);
const
auto
N1_
=
b_b
lock
_d
esc_
n
0_
k_n1
.
GetLength
(
I2
);
const
auto
N0_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I0
);
const
auto
K_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I1
);
const
auto
N1_
=
BSrcB
lock
D
esc_
N
0_
K_N1
{}
.
GetLength
(
I2
);
constexpr
auto
b_block_desc_n_k
=
transform_tensor_descriptor
(
//(32, 128) //(64, 128)
b_b
lock
_d
esc_
n
0_
k_n1
,
BSrcB
lock
D
esc_
N
0_
K_N1
{}
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
N0_
,
N1_
)),
//(4, 8) //(8, 8)
make_pass_through_transform
(
K_
)),
// 128 // 128
...
...
@@ -1141,16 +1135,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// S Gemm (type 4 rcc, B in LDS)
template
<
typename
BSrcBlockDesc_K0_N_K1
>
struct
Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
...
...
@@ -1204,9 +1195,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GemmDataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_b
lock
_d
esc_
bk0_n_bk1
)
,
BSrcB
lock
D
esc_
K0_N_K1
,
decltype
(
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K
(
b_b
lock
_d
esc_
bk0_n_bk1
)),
decltype
(
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K
(
BSrcB
lock
D
esc_
K0_N_K1
{}
)),
MPerBlock
,
NPerBlock
,
KPerBlock
,
...
...
@@ -1426,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true
,
// DstResetCoord
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_vgpr_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
...
@@ -1520,8 +1511,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
sizeof
(
GemmDataType
)
/
sizeof
(
FloatGemmAcc
);
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
max_lds_align
);
D0Operator
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
d0_block_space_offset
=
k_block_space_size_aligned
.
value
*
sizeof
(
GemmDataType
)
/
D0Operator
::
template
TypeTransform
<
D0DataType
>
::
Size
;
...
...
@@ -1697,6 +1687,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up dP Gemm (type 1 rcc)
//
using
Gemm0
=
Gemm0
<
decltype
(
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
)
>
;
// Gemm0: LDS allocation for A and B: be careful of alignment
auto
gemm0_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
...
...
@@ -1733,6 +1725,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up S Gemm (type 4 rcc)
//
using
Gemm3
=
Gemm3
<
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm3: LDS allocation for A and B: be careful of alignment
auto
gemm3_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
...
...
@@ -1838,7 +1832,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// set up dQ Gemm (type 3 crr)
//
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
)
>
;
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
),
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment