gaoqiong / composable_kernel · Commit b5fbb74b

add GemmDataType

Authored Feb 27, 2023 by ltqin
Parent: 9096e2af

Showing 4 changed files with 40 additions and 50 deletions (+40 −50)
Changed files:

  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (+10 −7)
  include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp  (+4 −2)
  include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp  (+0 −1)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp  (+26 −40)
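In outline, the commit threads a new GemmDataType template parameter from the fp16 example, through the device-level operator, and into the gridwise backward kernel, where it replaces the internally derived LDSDataType (previously produced by a TypeMap specialization) in MFMA instruction selection, LDS staging, and shared-memory size accounting. A minimal sketch of that plumbing, using simplified stand-in templates rather than the real CK class signatures:

// Simplified stand-ins for the CK templates touched by this commit;
// the real classes take many more template parameters.
template <typename DataType, typename GemmDataType /* new in this commit */>
struct GridwiseAttentionBackward
{
    // GemmDataType now drives MFMA selection and LDS staging instead of a
    // locally derived LDSDataType.
};

template <typename DataType, typename GemmDataType /* new in this commit */>
struct DeviceAttentionBackward
{
    using GridwiseGemm = GridwiseAttentionBackward<DataType, GemmDataType>;
};

// Example-level instantiation: this commit sets both types to fp16.
struct F16;
using DeviceInstance = DeviceAttentionBackward<F16, F16>;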
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -61,6 +61,7 @@ using YElementOp = PassThrough;
 using VElementOp = Scale;
 
 using DataType = F16;
+using GemmDataType = F16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;

@@ -97,6 +98,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,

@@ -164,6 +166,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
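Because the example now passes GemmDataType explicitly, the choice that the deleted TypeMap used to make inside the gridwise kernel (remapping fp16 to bf16 for the GEMM path under __gfx90a_masking__) could instead be expressed at the example level. A hypothetical variant of the aliases above, shown for illustration only; the commit itself keeps GemmDataType = F16:

// Hypothetical alternative to the example's type aliases; not part of this commit.
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;

using DataType = F16; // element type of the tensors in global memory
#if defined(__gfx90a_masking__)
using GemmDataType = BF16; // feed the XDL pipeline in bf16, as the removed TypeMap did implicitly
#else
using GemmDataType = F16;
#endif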
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp

@@ -49,7 +49,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
     kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
         const DataType* __restrict__ p_a_grid,

@@ -171,6 +171,7 @@ template <index_t NumDimG,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
           typename DataType,
+          typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,

@@ -595,9 +596,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         // GridwiseGemm
         using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
             DataType, // TODO: distinguish A/B datatype
-            LSEDataType,
+            GemmDataType,
             GemmAccDataType,
             CShuffleDataType,
+            LSEDataType,
             AElementwiseOperation,
             BElementwiseOperation,
             AccElementwiseOperation,
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -602,7 +602,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         {
             is_lse_storing_ = false;
         }
     }
-
     void Print() const
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -21,6 +21,7 @@
 namespace ck {
 
 template <typename DataType,
+          typename GemmDataType,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           typename FloatLSE,
@@ -85,21 +86,6 @@ template <typename DataType,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 {
-    template <typename T>
-    struct TypeMap
-    {
-        using type = T;
-    };
-
-#if defined(__gfx90a_masking__)
-    template <>
-    struct TypeMap<ck::half_t>
-    {
-        using type = ck::bhalf_t;
-    };
-#endif
-
-    using LDSDataType = typename TypeMap<DataType>::type;
 
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
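The deleted block was a small compile-time type map: a primary template that leaves the type unchanged, plus a specialization, guarded by __gfx90a_masking__, that remaps ck::half_t to ck::bhalf_t. A self-contained sketch of the same idiom, using local stand-in types instead of the CK ones:

#include <type_traits>

struct half_t  {}; // stand-in for ck::half_t
struct bhalf_t {}; // stand-in for ck::bhalf_t

// Primary template: identity mapping.
template <typename T>
struct TypeMap { using type = T; };

// Specialization: remap half to bfloat16, as the gfx90a branch did.
template <>
struct TypeMap<half_t> { using type = bhalf_t; };

static_assert(std::is_same_v<TypeMap<float>::type, float>, "unmapped types pass through");
static_assert(std::is_same_v<TypeMap<half_t>::type, bhalf_t>, "half is remapped");

After this commit the mapping is no longer baked into the gridwise kernel; whatever the caller passes as GemmDataType is used directly.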
@@ -141,7 +127,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);
 
-        constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3 = mfma.num_groups_per_blk;
         constexpr auto N4 = mfma.num_input_blks;
         constexpr auto N5 = mfma.group_size;
@@ -157,7 +143,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto
     MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
     {
-        constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3 = mfma.num_groups_per_blk;
         constexpr auto N4 = mfma.num_input_blks;
         constexpr auto N5 = mfma.group_size;
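Both hunks above only swap the first template argument of MfmaSelector, but that argument decides which MFMA instruction is selected, and with it the num_groups_per_blk, num_input_blks, and group_size values (N3/N4/N5) that shape the Z descriptor. A rough sketch of that selection-by-type pattern, with made-up instruction parameters rather than the real CK MfmaSelector table:

// Toy selector keyed on the element type; the numbers are placeholders,
// not real MFMA instruction parameters.
struct MfmaInstr
{
    int num_groups_per_blk;
    int num_input_blks;
    int group_size;
    int k_per_blk;
};

struct half_t {};  // stand-in for ck::half_t
struct bhalf_t {}; // stand-in for ck::bhalf_t

template <typename T, int MPerXdl, int NPerXdl>
struct MfmaSelector; // left undefined: unsupported combinations fail to compile

template <>
struct MfmaSelector<half_t, 32, 32>
{
    static constexpr MfmaInstr selected_mfma{4, 2, 4, 8}; // placeholder values
};

template <>
struct MfmaSelector<bhalf_t, 32, 32>
{
    static constexpr MfmaInstr selected_mfma{4, 2, 4, 8}; // placeholder values
};

// As in the hunks above, changing the type argument changes which entry is used.
constexpr auto mfma = MfmaSelector<half_t, 32, 32>::selected_mfma;
static_assert(mfma.group_size == 4, "descriptor lengths follow the selected instruction");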
@@ -471,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             DataType,
-            LDSDataType,
+            GemmDataType,
             GridDesc_K0_M_K1,
             decltype(a_block_desc_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,
@@ -496,7 +482,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             DataType,
-            LDSDataType,
+            GemmDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,
@@ -513,12 +499,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static constexpr index_t KPack =
             math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         // Blockwise gemm with transposed XDL output
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            LDSDataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),
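The KPack expression itself is unchanged apart from the type: it takes the larger of lcm(AK1, BK1) and the k_per_blk of the MFMA instruction selected for GemmDataType. A worked example with assumed values (AK1 = BK1 = 8, k_per_blk = 4; illustrative numbers, not taken from the commit):

#include <algorithm>
#include <numeric>

constexpr int AK1       = 8; // assumed A-side vector width
constexpr int BK1       = 8; // assumed B-side vector width
constexpr int k_per_blk = 4; // assumed k-per-block of the selected MFMA instruction

// Same shape as the CK expression: math::max(math::lcm(AK1, BK1), ...k_per_blk)
constexpr int KPack = std::max(std::lcm(AK1, BK1), k_per_blk);
static_assert(KPack == 8, "lcm(8, 8) = 8 dominates k_per_blk = 4");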
@@ -580,7 +566,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatGemmAcc,
-            LDSDataType,
+            GemmDataType,
             decltype(a_src_thread_desc_k0_m_k1),
             decltype(a_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
@@ -599,7 +585,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             B1BlockTransferThreadClusterLengths_BK0_N_BK1,
             B1BlockTransferThreadClusterArrangeOrder,
             DataType,
-            LDSDataType,
+            GemmDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
             B1BlockTransferSrcAccessOrder,
@@ -630,11 +616,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
         static constexpr index_t GemmKPack =
-            MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
+            MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
 
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            LDSDataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_thread_desc_k0_m_k1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -650,7 +636,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             GemmKPack,
             true,      // TransposeC
             GemmKPack, // AMmaKStride
-            GemmKPack * XdlopsGemm<LDSDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
+            GemmKPack * XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
                 .K0PerXdlops /* BMmaKStride */>;
     };
@@ -682,7 +668,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
         static constexpr index_t GemmMPack =
             math::max(math::lcm(A_M1, B_M1),
-                      MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
         using BThreadClusterLengths =
@@ -807,7 +793,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
             FloatGemmAcc,
-            LDSDataType,
+            GemmDataType,
             decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             ElementwiseOp,
@@ -837,7 +823,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             typename Gemm2Params_N_O_M::BThreadClusterLengths,
             typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
             DataType,
-            LDSDataType,
+            GemmDataType,
             GridDesc_M0_O_M1,
             decltype(b_block_desc_m0_o_m1),
             typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
@@ -854,7 +840,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         using BlockwiseGemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                 LDSDataType,
+                                                                 GemmDataType,
                                                                  FloatGemmAcc,
                                                                  decltype(a_block_desc_m0_n_m1),
                                                                  decltype(b_block_desc_m0_o_m1),
@@ -1095,7 +1081,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static constexpr auto b2_block_desc_m0_o_m1 =
             GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
 
-        static constexpr auto max_lds_align = Number<16 / sizeof(LDSDataType)>{};
+        static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
 
         static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -1131,13 +1117,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     {
         const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                          SharedMemTrait::b_block_space_size_aligned) *
-                                        sizeof(LDSDataType);
+                                        sizeof(GemmDataType);
         const index_t gemm1_bytes_end =
             (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(LDSDataType);
+            sizeof(GemmDataType);
         const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
                                               SharedMemTrait::ygrad_block_space_size_aligned) *
-                                             sizeof(LDSDataType);
+                                             sizeof(GemmDataType);
         const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                            SharedMemTrait::reduction_space_size_aligned) *
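These byte counts are element counts multiplied by the size of the staging type, so switching the multiplier from sizeof(LDSDataType) to sizeof(GemmDataType) changes the reported LDS footprint whenever the two types differ in width. A small stand-alone illustration of the accounting; the element counts are made up:

#include <cstddef>
#include <cstdint>

// Assumed element counts for the A and B staging tiles; placeholders only.
constexpr std::size_t a_block_space_size_aligned = 4096;
constexpr std::size_t b_block_space_size_aligned = 8192;

template <typename GemmDataType>
constexpr std::size_t gemm0_bytes_end()
{
    return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(GemmDataType);
}

// A 2-byte staging type needs half the LDS of a 4-byte one for the same tile shape.
static_assert(gemm0_bytes_end<std::uint16_t>() == 24576);
static_assert(gemm0_bytes_end<float>() == 49152);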
@@ -1243,11 +1229,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // Gemm0: LDS allocation for A and B: be careful of alignment
         auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
             Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
         auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
             Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         // Gemm0: gridwise GEMM pipeline
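The allocation hunks follow one pattern: a single raw shared-memory block (p_shared) is carved into typed tiles by casting to the staging element type and adding a per-tile element offset, so the cast type has to match the type used in the byte accounting above. A simplified sketch of that pattern; the layout struct and offsets are illustrative, not the CK SharedMemTrait values:

// Sketch of carving typed tiles out of one raw LDS allocation.
// In a real HIP kernel p_shared would point at a __shared__ / dynamic LDS buffer.
template <typename GemmDataType>
struct SharedMemLayout
{
    static constexpr int a_block_space_offset = 0;    // in elements of GemmDataType
    static constexpr int b_block_space_offset = 4096; // illustrative element offset
};

template <typename GemmDataType>
__device__ void carve_tiles(void* p_shared, GemmDataType*& a_tile, GemmDataType*& b_tile)
{
    using Layout = SharedMemLayout<GemmDataType>;
    a_tile = static_cast<GemmDataType*>(p_shared) + Layout::a_block_space_offset;
    b_tile = static_cast<GemmDataType*>(p_shared) + Layout::b_block_space_offset;
}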
@@ -1339,11 +1325,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
 
         // Gemm1: VGPR allocation for A and LDS allocation for B
-        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, LDSDataType>(
+        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
             Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
         auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         // dQ: transform input and output tensor descriptors
@@ -1535,11 +1521,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // Gemm2: LDS allocation for A and B: be careful of alignment
         auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
             Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
         auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
             Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());
 
         // dV: transform input and output tensor descriptors