gaoqiong / composable_kernel

Commit 32203c74, authored Nov 01, 2023 by danyao12
bwd hd256
Parent: 339b86e9

Showing 3 changed files with 4316 additions and 1 deletion:
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp  (+15, -1)
  include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v3.hpp  (+1509, -0)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v3.hpp  (+2792, -0)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
...
@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 256 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
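The multiple-of-8 requirement on DIM (it matches the 8-wide vector accesses configured in the instances below) is only stated in the comment. A minimal compile-time guard, shown here as a hedged sketch that is not part of this commit, could make it explicit:

// Hypothetical guard, not in this commit: fail compilation early when DIM
// violates the multiple-of-8 requirement stated in the comment above.
static_assert(DIM % 8 == 0, "DIM should be a multiple of 8");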
...
@@ -38,6 +38,7 @@ Kernel outputs:
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v3.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
...
@@ -150,6 +151,19 @@ using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 64, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 64, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 256)
// clang-format off
using DeviceGemmInstance =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 64, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 64, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V3< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 32, 256, 32, 64, 8, 8, 2, 32, 32, 1, 1, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V3<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType,
        ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType,
        QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ,
        TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 32, 256, 32, 64, 8, 8, 2, 32, 32,
        1, 1, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>,
        S<1, 0, 2>, 2, 8, 8, true, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4,
        S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#endif
// Ref Gemm0: S = alpha * Q * K^T
...
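The example validates the fused kernel against naive host-side reference GEMMs. As a hedged sketch of what the first reference step, "Ref Gemm0: S = alpha * Q * K^T", computes (the function name and layout choices here are illustrative, not the example's actual reference helpers):

// Illustrative sketch only: naive host Gemm0 of the attention chain,
// S[m][n] = alpha * sum_k Q[m][k] * K[n][k]  (K enters transposed).
#include <vector>

std::vector<float> ref_gemm0(const std::vector<float>& Q, // M x Kdim, row-major
                             const std::vector<float>& K, // N x Kdim, row-major
                             int M, int N, int Kdim, float alpha)
{
    std::vector<float> S(static_cast<size_t>(M) * N, 0.f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < Kdim; ++k)
                acc += Q[m * Kdim + k] * K[n * Kdim + k];
            S[m * N + n] = alpha * acc;
        }
    return S;
}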
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v3.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
          typename InputDataType,
          typename D0DataType,
          typename OutputDataType,
          typename ZDataType,
          typename LSEDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename D0GridDescriptor_M0_N0_M1_M2_N1_M3,
          typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
          typename B1GridDesc_BK0_N_BK1,
          typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
          typename LSEGridDescriptor_M,
          typename YGradGridDesc_M0_O_M1,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
          typename C0MatrixMask,
          bool HasMainKBlockLoop,
          bool IsDropout,
          bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
    kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v3(
        const InputDataType* __restrict__ p_a_grid,
        const InputDataType* __restrict__ p_b_grid,
        const D0DataType* __restrict__ p_d0_grid,
        ZDataType* __restrict__ p_z_grid,
        const InputDataType* __restrict__ p_b1_grid,
        const InputDataType* __restrict__ p_c_grid,
        const LSEDataType* __restrict__ p_lse_grid,
        const InputDataType* __restrict__ p_ygrad_grid,
        OutputDataType* __restrict__ p_qgrad_grid,
        OutputDataType* __restrict__ p_kgrad_grid,
        D0DataType* __restrict__ p_d0grad_grid,
        OutputDataType* __restrict__ p_vgrad_grid,
        const AElementwiseOperation a_element_op,
        const BElementwiseOperation b_element_op,
        const AccElementwiseOperation acc_element_op,
        const B1ElementwiseOperation b1_element_op,
        const CElementwiseOperation c_element_op,
        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
        const BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1,
        const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
        const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
        const B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1,
        const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock,
        const LSEGridDescriptor_M lse_grid_desc_m,
        const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
        const Block2CTileMap block_2_ctile_map,
        const index_t batch_count,
        const index_t h_ratio,
        const index_t nblock,
        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
        const C0MatrixMask c0_matrix_mask,
        const float p_drop,
        const unsigned long long seed,
        const unsigned long long offset,
        const index_t raw_m_padded,
        const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
    const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);

    // NOTE: assumes Q/K/V/Y have the same layout as dQ/dK/dV/dY, so the batch offsets can be
    // reused
    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
    const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
    const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
    const long_index_t bgrad_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBGradBasePtr(g_idx)));
    const long_index_t b1grad_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1GradBasePtr(g_idx)));

    ck::philox ph(seed, 0, offset);
    ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
    const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

    const D0DataType* tmp_p_d0_grid = nullptr;
    D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
        if(p_d0_grid != nullptr)
        {
            tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
        }
        if(p_d0grad_grid != nullptr)
        {
            tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
        }
    }

    if constexpr(Deterministic)
    {
        for(index_t i = 0; i < nblock; i++)
        {
            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
                p_a_grid + a_batch_offset,
                p_b_grid + b_batch_offset,
                tmp_p_d0_grid,
                z_matrix_ptr,
                p_b1_grid + b1_batch_offset,
                p_c_grid + c_batch_offset,
                p_lse_grid + lse_batch_offset,
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + bgrad_batch_offset,
                tmp_p_d0grad_grid,
                p_vgrad_grid + b1grad_batch_offset,
                p_shared,
                a_element_op, b_element_op, acc_element_op, b1_element_op, c_element_op,
                a_grid_desc_ak0_m_ak1,
                b_grid_desc_bk0_n_bk1,
                bgrad_grid_desc_bk0_n_bk1,
                d0_grid_desc_m0_n0_m1_m2_n1_m3,
                c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                b1_grid_desc_bk0_n_bk1,
                b1grad_grid_desc_bk0_n_bk1,
                c_grid_desc_mblock_mperblock_nblock_nperblock,
                lse_grid_desc_m,
                ygrad_grid_desc_m0_o_m1,
                block_2_ctile_map,
                c0_matrix_mask,
                p_drop, ph, z_random_matrix_offset, raw_n_padded, i);
        }
    }
    else
    {
        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
            p_a_grid + a_batch_offset,
            p_b_grid + b_batch_offset,
            tmp_p_d0_grid,
            z_matrix_ptr,
            p_b1_grid + b1_batch_offset,
            p_c_grid + c_batch_offset,
            p_lse_grid + lse_batch_offset,
            p_ygrad_grid + c_batch_offset,
            p_qgrad_grid + a_batch_offset,
            p_kgrad_grid + bgrad_batch_offset,
            tmp_p_d0grad_grid,
            p_vgrad_grid + b1grad_batch_offset,
            p_shared,
            a_element_op, b_element_op, acc_element_op, b1_element_op, c_element_op,
            a_grid_desc_ak0_m_ak1,
            b_grid_desc_bk0_n_bk1,
            bgrad_grid_desc_bk0_n_bk1,
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
            b1_grid_desc_bk0_n_bk1,
            b1grad_grid_desc_bk0_n_bk1,
            c_grid_desc_mblock_mperblock_nblock_nperblock,
            lse_grid_desc_m,
            ygrad_grid_desc_m0_o_m1,
            block_2_ctile_map,
            c0_matrix_mask,
            p_drop, ph, z_random_matrix_offset, raw_n_padded, 0);
    }
#else
    ignore = p_a_grid;      ignore = p_b_grid;       ignore = p_d0_grid;
    ignore = p_z_grid;      ignore = p_b1_grid;      ignore = p_c_grid;
    ignore = p_lse_grid;    ignore = p_ygrad_grid;   ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;  ignore = p_d0grad_grid;  ignore = p_vgrad_grid;
    ignore = a_element_op;  ignore = b_element_op;   ignore = acc_element_op;
    ignore = b1_element_op; ignore = c_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = bgrad_grid_desc_bk0_n_bk1;
    ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
    ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
    ignore = b1_grid_desc_bk0_n_bk1;
    ignore = b1grad_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = lse_grid_desc_m;
    ignore = ygrad_grid_desc_m0_o_m1;
    ignore = block_2_ctile_map;
    ignore = batch_count;   ignore = h_ratio;        ignore = nblock;
    ignore = compute_base_ptr_of_batch;
    ignore = c0_matrix_mask;
    ignore = p_drop;        ignore = seed;           ignore = offset;
    ignore = raw_m_padded;  ignore = raw_n_padded;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || ...)
}
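The kernel derives its batch indices from the flat block id: g_idx selects the (batch * head) slice a block works on, and gkv_idx divides by h_ratio so that several query heads can share one K/V head (as in grouped-query or multi-query attention). A hedged host-side sketch of the same arithmetic, with hypothetical names:

#include <cassert>

// Illustrative sketch only (not the kernel's code): map a flat block id to the
// G slice it works on and the K/V head it reads, mirroring the
// __builtin_amdgcn_readfirstlane arithmetic above.
void map_block(int block_id, int grid_size, int batch_count, int h_ratio,
               int& g_idx, int& gkv_idx)
{
    assert(grid_size % batch_count == 0);
    const int num_blocks_per_batch = grid_size / batch_count;
    g_idx   = block_id / num_blocks_per_batch; // which batch*head slice
    gkv_idx = g_idx / h_ratio;                 // several Q heads share one K/V head
}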
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG, index_t NumDimM, index_t NumDimN, index_t NumDimK,
          index_t NumDimO, // NumDimGemm1N
          typename InputDataType, typename OutputDataType, typename GemmDataType,
          typename ZDataType, typename LSEDataType, typename Acc0BiasDataType,
          typename Acc1BiasDataType, typename GemmAccDataType, typename CShuffleDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename AccElementwiseOperation, typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec, TensorSpecialization BSpec,
          TensorSpecialization B1Spec, TensorSpecialization CSpec,
          index_t NumGemmKPrefetchStage, index_t BlockSize, index_t MPerBlock,
          index_t NPerBlock, // Gemm0NPerBlock
          index_t KPerBlock, // Gemm0KPerBlock
          index_t Gemm1NPerBlock, index_t Gemm1KPerBlock, index_t Gemm2KPerBlock,
          index_t AK1, index_t BK1, index_t B1K1, index_t MPerXDL, index_t NPerXDL,
          index_t MXdlPerWave, index_t NXdlPerWave, index_t Gemm1NXdlPerWave,
          index_t Gemm2NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t D0BlockTransferSrcScalarPerVector,
          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          index_t B1BlockTransferSrcVectorDim,
          index_t B1BlockTransferSrcScalarPerVector,
          index_t B1BlockTransferDstScalarPerVector_BK1,
          bool B1BlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          MaskingSpecialization MaskingSpec,
          bool Deterministic,
          LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V3
    : public BaseOperator // TODO: inherit atten bwd op once API stabilizes
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                  "Number of dimension must be greater than 0");

    using D0DataType = Acc0BiasDataType;
    using D1DataType = Acc1BiasDataType;

    // TODO: implement bias combination
    static_assert(std::is_void<D1DataType>::value, "Acc1 Bias addition is unimplemented");

    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V3;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t V_O1 = BK1;
    static constexpr index_t Y_O1 = AK1;
    static constexpr index_t Y_M1 = B1K1;

    static constexpr auto padder = GemmGemmPadder<GemmSpec,
                                                  Number<MPerBlock>,
                                                  Number<NPerBlock>,
                                                  Number<KPerBlock>,
                                                  Number<Gemm1NPerBlock>>{};

    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
        Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
        GemmSpec,
        ASpec,
        BSpec,
        B1Spec,
        CSpec>;

    /*
    Descriptors for inputs:  Q, K, V, Y, dY, per-row softmax stats
    Descriptors for outputs: dQ, dK, dV
    */
    // Q in Gemm A position
    static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
                                              const std::vector<index_t>& a_gs_ms_ks_strides)
    {
        return Transform::MakeAGridDescriptor_AK0_M_AK1(
            Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
            Number<AK1>{});
    }

    // K in Gemm B0 position
    static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
                                              const std::vector<index_t>& b_gs_ns_ks_strides)
    {
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
            Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
            Number<BK1>{});
    }

    // V in Gemm B1 position
    static auto
    MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths,
                                   const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides)
    {
        return Transform::MakeB1GridDescriptor_BK0_N_BK1(
            Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths,
                                                b1_gs_gemm1ns_gemm1ks_strides),
            Number<B1K1>{});
    }
    //
    // dV = P^T * dY
    //
    // VGrad in Gemm C position
    static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
                                            const std::vector<index_t>& v_gs_os_ns_strides)
    {
        // v_gs_os_ns -> vgrad_gs_ns_os. O dims go last because the output is row-major.
        // Rearrange lengths/strides directly before constructing the tensor descriptor to
        // reduce transformation overhead.
        // TODO: this will be much easier once inputs are Gs, Ms, Ns, Os, so there is no need
        // to extract a subsequence and shuffle it.
        const index_t num_dims = NumDimG + NumDimN + NumDimO;

        // 0, 1, ..., NumDimG - 1
        std::vector<index_t> gs_ids(NumDimG);
        std::iota(gs_ids.begin(), gs_ids.end(), 0);

        // NumDimG, NumDimG + 1, ..., NumDimG + NumDimO - 1
        std::vector<index_t> os_ids(NumDimO);
        std::iota(os_ids.begin(), os_ids.end(), NumDimG);

        // NumDimG + NumDimO, ..., NumDimG + NumDimO + NumDimN - 1
        std::vector<index_t> ns_ids(NumDimN);
        std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);

        std::vector<index_t> ids_old2new;
        ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
        ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
        ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());

        std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
        for(int i = 0; i < num_dims; i++)
        {
            index_t id_new        = ids_old2new[i];
            v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
            v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
        }

        const auto vgrad_desc_nraw_oraw =
            MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
                v_gs_ns_os_lengths, v_gs_ns_os_strides)
                .second;

        return PadTensorDescriptor(vgrad_desc_nraw_oraw,
                                   make_tuple(NPerBlock, Gemm1NPerBlock),
                                   Sequence<padder.PadN, padder.PadO>{});
    }
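The reshuffle above builds an old-to-new index map with std::iota and applies it to lengths and strides, turning a (Gs, Os, Ns) ordering into (Gs, Ns, Os). A hedged standalone sketch of the same permutation, with hypothetical names and plain int in place of ck::index_t:

#include <numeric>
#include <vector>

// Illustrative sketch only: permute a gs_os_ns length (or stride) vector into
// gs_ns_os order, as MakeVGradGridDescriptor_N_O does above.
std::vector<int> permute_os_ns(const std::vector<int>& v_gs_os_ns,
                               int NumDimG, int NumDimO, int NumDimN)
{
    std::vector<int> ids(NumDimG + NumDimO + NumDimN);
    std::iota(ids.begin(), ids.end(), 0);

    std::vector<int> old2new;
    old2new.insert(old2new.end(), ids.begin(), ids.begin() + NumDimG);                     // Gs
    old2new.insert(old2new.end(), ids.begin() + NumDimG + NumDimO, ids.end());             // Ns
    old2new.insert(old2new.end(), ids.begin() + NumDimG, ids.begin() + NumDimG + NumDimO); // Os

    std::vector<int> permuted(v_gs_os_ns.size());
    for(std::size_t i = 0; i < permuted.size(); ++i)
        permuted[i] = v_gs_os_ns[old2new[i]];
    return permuted;
}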
    template <typename YGridDesc_M_O>
    static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o)
    {
        const auto M    = ygrad_grid_desc_m_o.GetLength(I0);
        const auto O    = ygrad_grid_desc_m_o.GetLength(I1);
        const auto Y_M0 = M / Y_M1;

        return transform_tensor_descriptor(
            ygrad_grid_desc_m_o,
            make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
                       make_pass_through_transform(O)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
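The unmerge transform splits the M extent into M0 * M1 with M1 = B1K1 fixed at compile time, so the descriptor becomes [M0, O, M1]. A hedged numeric illustration with hypothetical values:

// Illustrative arithmetic only: with Y_M1 = B1K1 = 2 and a padded M of 128,
// M is viewed as (M0, M1) = (64, 2), and a row index decomposes as
// m = m0 * Y_M1 + m1.
void ygrad_split_example()
{
    const int M    = 128;
    const int Y_M1 = 2;        // = B1K1
    const int Y_M0 = M / Y_M1; // = 64
    (void)Y_M0;
}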
    //
    // dP = dY * V^T
    //
    // YGrad in Gemm A position
    static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
                                                const std::vector<index_t>& y_gs_ms_os_strides)
    {
        return Transform::MakeAGridDescriptor_AK0_M_AK1(
            Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
            Number<Y_O1>{});
    }

    // V in Gemm B position
    static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
                                            const std::vector<index_t>& v_gs_os_ns_strides)
    {
        // v_gs_os_ns -> v_gs_ns_os. O dims go last because the output is row-major.
        // Rearrange lengths/strides directly before constructing the tensor descriptor to
        // reduce transformation overhead.
        // TODO: this will be much easier once inputs are Gs, Ms, Ns, Os, so there is no need
        // to extract a subsequence and shuffle it.
        const index_t num_dims = NumDimG + NumDimN + NumDimO;

        // 0, 1, ..., NumDimG - 1
        std::vector<index_t> gs_ids(NumDimG);
        std::iota(gs_ids.begin(), gs_ids.end(), 0);

        // NumDimG, NumDimG + 1, ..., NumDimG + NumDimO - 1
        std::vector<index_t> os_ids(NumDimO);
        std::iota(os_ids.begin(), os_ids.end(), NumDimG);

        // NumDimG + NumDimO, ..., NumDimG + NumDimO + NumDimN - 1
        std::vector<index_t> ns_ids(NumDimN);
        std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);

        std::vector<index_t> ids_old2new;
        ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
        ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
        ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());

        std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
        for(int i = 0; i < num_dims; i++)
        {
            index_t id_new        = ids_old2new[i];
            v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
            v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
        }

        const auto v_grid_desc_nraw_oraw =
            MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
                v_gs_ns_os_lengths, v_gs_ns_os_strides)
                .second;

        const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
                                                         make_tuple(NPerBlock, Gemm1NPerBlock),
                                                         Sequence<padder.PadN, padder.PadO>{});

        // N_O to O0_N_O1; to refactor
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
    }
    // D0 in Gemm0 C position
    static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
                                         const std::vector<index_t>& d_gs_ms_ns_strides)
    {
        return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
    }

    // Z in Gemm0 C position
    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                        const std::vector<index_t>& z_gs_ms_ns_strides)
    {
        return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    }

    static auto MakeLSEGridDescriptor_M(index_t MRaw)
    {
        const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto MPad = M - MRaw;

        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                     GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M
            return transform_tensor_descriptor(lse_grid_desc_mraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
                                               make_tuple(Sequence<0>{}),
                                               make_tuple(Sequence<0>{}));
        }
        else
        {
            // do not pad M
            return lse_grid_desc_mraw;
        }
    }
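The LSE descriptor pads MRaw up to the next multiple of MPerBlock only when the GEMM specialization requests M padding; the padded extent follows the usual ceil-divide pattern. A hedged sketch of the arithmetic:

// Illustrative sketch only: the padded-M arithmetic used above.
// (MRaw + MPerBlock - 1) / MPerBlock rounds up the block count, so M is MRaw
// rounded up to a multiple of MPerBlock; M - MRaw is the right-pad amount.
int padded_m(int MRaw, int MPerBlock)
{
    return (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock;
}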
    using AGridDesc_AK0_M_AK1   = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1   = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using D0GridDesc_G_M_N      = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
    using B1GridDesc_BK0_N_BK1  = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
    using YGridDesc_M_O         = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using LSEGridDesc_M         = decltype(MakeLSEGridDescriptor_M(1));
    using AGridDesc_G_M_K       = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using BGridDesc_G_N_K       = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K      = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N       = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    using ZGridDesc_G_M_N       = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
    using D0GridDesc_M_N        = decltype(MakeD0GridDescriptor_M_N({}, {}));
    using KGridDesc_N_K         = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
    using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
    using ZGridDesc_M_N         = decltype(MakeZGridDescriptor_M_N({}, {}));
    constexpr static auto make_MaskOutPredicate()
    {
        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
        {
            return MaskDisabledPredicate{};
        }
        else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
        {
            return MaskUpperTriangleFromTopLeftPredicate{};
        }
        else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
        {
            return MaskUpperTriangleFromBottomRightPredicate{};
        }
    }
    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch() {}

        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                     const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
                                     const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n,
                                     const BGridDesc_G_N_K& bgrad_grid_desc_g_n_k,
                                     const B1GridDesc_G_N_K& b1grad_grid_desc_g_n_k,
                                     index_t BatchStrideLSE)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
              d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
              z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
              bgrad_grid_desc_g_n_k_(bgrad_grid_desc_g_n_k),
              b1grad_grid_desc_g_n_k_(b1grad_grid_desc_g_n_k),
              BatchStrideLSE_(BatchStrideLSE)
        {
        }

        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
        {
            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
        {
            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
        {
            return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
        {
            return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
        {
            return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
        }

        __host__ __device__ constexpr long_index_t GetBGradBasePtr(index_t g_idx) const
        {
            return bgrad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB1GradBasePtr(index_t g_idx) const
        {
            return b1grad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        private:
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
        index_t BatchStrideLSE_;
    };
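Each Get*BasePtr returns the element offset of batch g, i.e. the descriptor's offset at index (g, 0, 0); for a plainly strided batch dimension this degenerates to g * batch_stride, which is exactly how GetLSEBasePtr is written. A hedged sketch of that degenerate case:

// Illustrative sketch only: for a simple strided-batch tensor the base offset
// of batch g is g * batch_stride, as in GetLSEBasePtr above. The other getters
// defer to the descriptor so non-trivial layouts also work.
long long batch_base_offset(int g_idx, long long batch_stride)
{
    return static_cast<long long>(g_idx) * batch_stride;
}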
    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V3<
        InputDataType, // TODO: distinguish A/B datatype
        D0DataType,
        OutputDataType,
        ZDataType,
        GemmDataType,
        GemmAccDataType,
        CShuffleDataType,
        LSEDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        AccElementwiseOperation,
        B1ElementwiseOperation,
        CElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        KGridDesc_N_K,
        D0GridDesc_M_N,
        ZGridDesc_M_N,
        B1GridDesc_BK0_N_BK1,
        YGridDesc_M_O,
        LSEGridDesc_M,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, Gemm2KPerBlock,
        AK1, BK1, B1K1,
        MPerXDL, NPerXDL,
        MXdlPerWave, NXdlPerWave, Gemm1NXdlPerWave, Gemm2NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        true,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        true,
        BBlockLdsExtraN,
        D0BlockTransferSrcScalarPerVector,
        B1BlockTransferThreadClusterLengths_BK0_N_BK1,
        B1BlockTransferThreadClusterArrangeOrder,
        B1BlockTransferSrcAccessOrder,
        B1BlockTransferSrcVectorDim,
        B1BlockTransferSrcScalarPerVector,
        B1BlockTransferDstScalarPerVector_BK1,
        true,
        B1BlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        Transform::matrix_padder.PadN,
        MaskingSpec != MaskingSpecialization::MaskDisabled,
        Deterministic>;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const InputDataType* p_a_grid,
                 const InputDataType* p_b_grid,
                 ZDataType* p_z_grid,
                 const InputDataType* p_b1_grid,
                 const InputDataType* p_c_grid, // for dS
                 const LSEDataType* p_lse_grid,
                 const InputDataType* p_ygrad_grid,
                 OutputDataType* p_qgrad_grid,
                 OutputDataType* p_kgrad_grid,
                 OutputDataType* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
                 D0DataType* p_d0grad_grid,
                 D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
                 const std::vector<index_t>& b_gs_ns_ks_strides,
                 const std::vector<index_t>& z_gs_ms_ns_lengths,
                 const std::vector<index_t>& z_gs_ms_ns_strides,
                 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                 const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                 const std::vector<index_t>& lse_gs_ms_lengths,
                 const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
                 const std::vector<index_t>& bgrad_gs_ns_ks_strides,
                 const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
                 const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
                 const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
                 const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
                 const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
                 const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 B1ElementwiseOperation b1_element_op,
                 CElementwiseOperation c_element_op,
                 float p_drop,
                 std::tuple<unsigned long long, unsigned long long> seeds)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_d0_grid_{p_acc0_bias},
              p_z_grid_{p_z_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              p_lse_grid_{p_lse_grid},
              p_ygrad_grid_{p_ygrad_grid},
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
              p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              bgrad_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
                  bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)},
              z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              b1grad_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                  b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
              y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                  c_gs_ms_gemm1ns_strides)},
              lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
              k_grid_desc_n_k_{
                  Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
              // batch offsets
              a_grid_desc_g_m_k_{
                  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_g_n_k_{
                  Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              z_grid_desc_g_m_n_{
                  Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              bgrad_grid_desc_g_n_k_{Transform::MakeB0GridDescriptor_G_N_K(
                  bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)},
              b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
                  b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
              y_grid_desc_mblock_mperblock_oblock_operblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              c0_matrix_mask_{a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1)},
              raw_lengths_mz_nz_kz_gemm1nz_{
                  a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                  b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
                  b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
                  b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
              a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
                               a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
              b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
                               b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
              b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
                                b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
              c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                    c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
              h_ratio_{c_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)},
              p_drop_{p_drop}
        {
            // TODO: implement bias addition
            ignore = p_acc1_bias;
            ignore = p_d1grad_grid;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;

            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           b1_grid_desc_bk0_n_bk1_,
                                           y_grid_desc_m_o_))
            {
                y_grid_desc_mblock_mperblock_oblock_operblock_ =
                    GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
                        y_grid_desc_m_o_);
            }

            if constexpr(!is_same<D0DataType, void>::value)
            {
                const auto d0_grid_desc_m_n = MakeD0GridDescriptor_M_N(
                    acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
                d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
                    GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
                d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
                    acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
                d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
                d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_strides[NumDimG + NumDimM]);
            }

            compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
                a_grid_desc_g_m_k_,
                b_grid_desc_g_n_k_,
                d0_grid_desc_g_m_n_,
                z_grid_desc_g_m_n_,
                b1_grid_desc_g_n_k_,
                c_grid_desc_g_m_n_,
                bgrad_grid_desc_g_n_k_,
                b1grad_grid_desc_g_n_k_,
                type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));

            seed_   = std::get<0>(seeds);
            offset_ = std::get<1>(seeds);

            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);

            // Print();

            m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
            n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
        }
        void Print() const
        {
            std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
                      << a_grid_desc_g_m_k_.GetLength(I1) << ", "
                      << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
            // a_grid_desc_g_m_k_.Print();
            std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
                      << b_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
            // b_grid_desc_g_n_k_.Print();
            std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
            // b1_grid_desc_g_n_k_.Print();
            std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
            // c_grid_desc_g_m_n_.Print();
            std::cout << "k_grid_desc_n_k_: " << k_grid_desc_n_k_.GetLength(I0) << ", "
                      << k_grid_desc_n_k_.GetLength(I1) << '\n';
            std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
                      << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
                      << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
            std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << d0_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
            std::cout << "bgrad_grid_desc_g_n_k_: " << bgrad_grid_desc_g_n_k_.GetLength(I0)
                      << ", " << bgrad_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n';
            // bgrad_grid_desc_g_n_k_.Print();
            std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0)
                      << ", " << b1grad_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n';
            // b1grad_grid_desc_g_n_k_.Print();
        }
        // pointers
        const InputDataType* p_a_grid_;
        const InputDataType* p_b_grid_;
        const D0DataType* p_d0_grid_;
        ZDataType* p_z_grid_;
        const InputDataType* p_b1_grid_;
        const InputDataType* p_c_grid_;
        const LSEDataType* p_lse_grid_;
        const InputDataType* p_ygrad_grid_;
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
        D0DataType* p_d0grad_grid_;

        // tensor descriptors
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
        typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
        ZGridDesc_M_N z_grid_desc_m_n_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
        B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
        YGridDesc_M_O y_grid_desc_m_o_;
        LSEGridDesc_M lse_grid_desc_m_;
        KGridDesc_N_K k_grid_desc_n_k_;
        YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_;

        // batch offsets
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
        BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;

        typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
            y_grid_desc_mblock_mperblock_oblock_operblock_;
        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;

        // element-wise ops
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;

        // check C0 masking and padding
        C0MatrixMask c0_matrix_mask_;

        // for robust IsSupportedArgument() checks
        std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
        std::vector<index_t> a_mz_kz_strides_;
        std::vector<index_t> b_nz_kz_strides_;
        std::vector<index_t> b1_nz_kz_strides_;
        std::vector<index_t> c_mz_gemm1nz_strides_;

        index_t batch_count_;
        index_t h_ratio_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
        float p_drop_;
        unsigned long long seed_;
        unsigned long long offset_;
        index_t m_raw_padded_;
        index_t n_raw_padded_;

        // raw data
        std::vector<ck::index_t> d0_n_length_stride_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!DeviceOp::IsSupportedArgument(arg))
            {
                throw std::runtime_error("wrong! unsupported argument");
            }

            const index_t grid_size =
                (Deterministic ? 1
                               : arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_)) *
                arg.batch_count_;

            // Gemm0_K
            const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                           arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            float ave_time = 0;

            auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
                const auto kernel =
                    kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v3<
                        GridwiseGemm,
                        InputDataType,
                        D0DataType,
                        OutputDataType,
                        ZDataType,
                        LSEDataType,
                        AElementwiseOperation,
                        BElementwiseOperation,
                        AccElementwiseOperation,
                        B1ElementwiseOperation,
                        CElementwiseOperation,
                        DeviceOp::AGridDesc_AK0_M_AK1,
                        DeviceOp::BGridDesc_BK0_N_BK1,
                        typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3,
                        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
                        DeviceOp::B1GridDesc_BK0_N_BK1,
                        typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
                        DeviceOp::LSEGridDesc_M,
                        DeviceOp::YGradGridDesc_M0_O_M1,
                        typename GridwiseGemm::DefaultBlock2CTileMap,
                        ComputeBasePtrOfStridedBatch,
                        C0MatrixMask,
                        has_main_k_block_loop_,
                        is_dropout_,
                        Deterministic>;

                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
                    arg.p_a_grid_,
                    arg.p_b_grid_,
                    arg.p_d0_grid_,
                    arg.p_z_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
                    arg.p_lse_grid_,
                    arg.p_ygrad_grid_,
                    arg.p_qgrad_grid_,
                    arg.p_kgrad_grid_,
                    arg.p_d0grad_grid_,
                    arg.p_vgrad_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.acc_element_op_,
                    arg.b1_element_op_,
                    arg.c_element_op_,
                    arg.a_grid_desc_ak0_m_ak1_,
                    arg.b_grid_desc_bk0_n_bk1_,
                    arg.bgrad_grid_desc_bk0_n_bk1_,
                    arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
                    arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
                    arg.b1_grid_desc_bk0_n_bk1_,
                    arg.b1grad_grid_desc_bk0_n_bk1_,
                    arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
                    arg.lse_grid_desc_m_,
                    arg.ygrad_grid_desc_m0_o_m1_,
                    arg.block_2_ctile_map_,
                    arg.batch_count_,
                    arg.h_ratio_,
                    arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_),
                    arg.compute_base_ptr_of_batch_,
                    arg.c0_matrix_mask_,
                    arg.p_drop_,
                    arg.seed_,
                    arg.offset_,
                    arg.m_raw_padded_,
                    arg.n_raw_padded_);
            };

            // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only
            // need to concern ourselves with Gemm0's loop
            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                if(arg.p_drop_ > 0.0)
                    ave_time = launch_kernel(integral_constant<bool, true>{},
                                             integral_constant<bool, true>{});
                else
                    ave_time = launch_kernel(integral_constant<bool, true>{},
                                             integral_constant<bool, false>{});
            }
            else
            {
                if(arg.p_drop_ > 0.0)
                    ave_time = launch_kernel(integral_constant<bool, false>{},
                                             integral_constant<bool, true>{});
                else
                    ave_time = launch_kernel(integral_constant<bool, false>{},
                                             integral_constant<bool, false>{});
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
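Run lifts the two runtime flags into template parameters by passing integral_constant values into the generic lambda, so HasMainKBlockLoop and IsDropout select one of four fully specialized kernels instead of branching at runtime inside the kernel. A hedged minimal example of the same dispatch pattern, with hypothetical names:

#include <type_traits>

// Illustrative sketch only: runtime bools lifted into compile-time template
// arguments via std::integral_constant, mirroring launch_kernel above.
template <bool MainLoop, bool Dropout>
void run_impl() { /* one fully specialized kernel per (MainLoop, Dropout) */ }

template <typename MainLoopC, typename DropoutC>
void launch(MainLoopC, DropoutC)
{
    run_impl<MainLoopC::value, DropoutC::value>();
}

void dispatch(bool has_main_loop, bool dropout)
{
    using std::integral_constant;
    if(has_main_loop && dropout)
        launch(integral_constant<bool, true>{}, integral_constant<bool, true>{});
    else if(has_main_loop)
        launch(integral_constant<bool, true>{}, integral_constant<bool, false>{});
    else if(dropout)
        launch(integral_constant<bool, false>{}, integral_constant<bool, true>{});
    else
        launch(integral_constant<bool, false>{}, integral_constant<bool, false>{});
}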
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }
    static bool IsSupportedArgument(const Argument& arg)
    {
#if DEBUG_LOG
        arg.Print();
#endif
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
             ck::get_device_name() == "gfx942"))
        {
            return false;
        }

        // TODO: check whether tensor specialization & strides mismatch
        // check whether the C permute dimension matches the GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
        const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
        const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);

        if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
        {
            return false;
        }

        if constexpr(!is_same<D0DataType, void>::value)
        {
            if(arg.d0_n_length_stride_[1] == 1 &&
               arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
            if(arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
            {
                return false;
            }
        }

        // Note: we need raw lengths since threadwise copy cannot handle a vector load when part
        // of the vector is out of bounds.
        // Note: we need the lowest dim in Ms/Ns/Ks/Os, not the merged M/N/K/O.
        const auto MzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
        const auto NzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
        const auto KzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
        const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

        // check scalar-per-vector requirement
        const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
        const auto b_extent_lowest  = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
        const auto c_extent_lowest  = Gemm1NzRaw;

        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            return false;
        }

        // dQ data is saved with the atomic_add instruction, so KzRaw must be a multiple of 2
        if constexpr(is_same<OutputDataType, half_t>::value ||
                     is_same<OutputDataType, bhalf_t>::value)
        {
            if(KzRaw % 2 != 0)
            {
                std::cout << "K_q must be a multiple of 2" << std::endl;
                return false;
            }
        }

        // check vector load/store requirement
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
        const auto b_stride_lowest =
            BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
        const auto b1_stride_lowest =
            B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
        // cshuffle assumes the lowest dim in Gemm1Ns to be contiguous
        const auto c_stride_lowest = arg.c_mz_gemm1nz_strides_[1];

        if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
                                           arg.y_grid_desc_m_o_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
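The per-tensor checks above reduce to one predicate: the raw extent along a tensor's vectorized dimension must divide evenly by the configured scalars-per-vector (so no vector access runs partially out of bounds), and that dimension must be unit-stride. A hedged sketch of that core predicate, with hypothetical names:

// Illustrative sketch only: the shape/stride predicate behind the checks above.
// A vectorized copy of `scalars_per_vector` elements along a dimension is legal
// when the raw extent divides evenly and the dimension is contiguous.
bool vector_access_ok(int raw_extent, int stride, int scalars_per_vector)
{
    return raw_extent % scalars_per_vector == 0 && stride == 1;
}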
    static auto MakeArgument(const InputDataType* p_a,
                             const InputDataType* p_b,
                             ZDataType* p_z,
                             const InputDataType* p_b1,
                             const InputDataType* p_c,
                             const LSEDataType* p_lse,
                             const InputDataType* p_ygrad_grid,
                             OutputDataType* p_qgrad_grid,
                             OutputDataType* p_kgrad_grid,
                             OutputDataType* p_vgrad_grid,
                             const D0DataType* p_acc0_bias,
                             const D1DataType* p_acc1_bias,
                             D0DataType* p_d0grad_grid,
                             D1DataType* p_d1grad_grid,
                             const std::vector<index_t>& a_gs_ms_ks_lengths,
                             const std::vector<index_t>& a_gs_ms_ks_strides,
                             const std::vector<index_t>& b_gs_ns_ks_lengths,
                             const std::vector<index_t>& b_gs_ns_ks_strides,
                             const std::vector<index_t>& z_gs_ms_ns_lengths,
                             const std::vector<index_t>& z_gs_ms_ns_strides,
                             const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                             const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                             const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                             const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                             const std::vector<index_t>& lse_gs_ms_lengths,
                             const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
                             const std::vector<index_t>& bgrad_gs_ns_ks_strides,
                             const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
                             const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
                             const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
                             const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
                             const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
                             const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op,
                             float p_drop,
                             std::tuple<unsigned long long, unsigned long long> seeds)
    {
        return Argument{p_a,
                        p_b,
                        p_z,
                        p_b1,
                        p_c,
                        p_lse,
                        p_ygrad_grid,
                        p_qgrad_grid,
                        p_kgrad_grid,
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
                        p_d0grad_grid,
                        p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
                        b_gs_ns_ks_strides,
                        z_gs_ms_ns_lengths,
                        z_gs_ms_ns_strides,
                        b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                        b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                        c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                        c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                        lse_gs_ms_lengths,
                        bgrad_gs_ns_ks_lengths,
                        bgrad_gs_ns_ks_strides,
                        b1grad_gs_gemm1ns_gemm1ks_lengths,
                        b1grad_gs_gemm1ns_gemm1ks_strides,
                        acc0_bias_gs_ms_ns_lengths,
                        acc0_bias_gs_ms_ns_strides,
                        acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
                        acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
                        a_element_op,
                        b_element_op,
                        acc_element_op,
                        b1_element_op,
                        c_element_op,
                        p_drop,
                        seeds};
    }
    static auto MakeInvoker() { return Invoker{}; }
    // polymorphic
    // FIXME: constness
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_z,
                        const void* p_b1,
                        const void* p_c,
                        const void* p_lse,
                        const void* p_ygrad_grid,
                        void* p_qgrad_grid,
                        void* p_kgrad_grid,
                        void* p_vgrad_grid,
                        const void* p_acc0_bias,
                        const void* p_acc1_bias,
                        void* p_d0grad_grid,
                        void* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
                        const std::vector<index_t>& b_gs_ns_ks_strides,
                        const std::vector<index_t>& z_gs_ms_ns_lengths,
                        const std::vector<index_t>& z_gs_ms_ns_strides,
                        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                        const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                        const std::vector<index_t>& lse_gs_ms_lengths,
                        const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
                        const std::vector<index_t>& bgrad_gs_ns_ks_strides,
                        const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
                        const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
                        const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
                        const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
                        const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
                        const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        AccElementwiseOperation acc_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op,
                        float p_drop,
                        std::tuple<unsigned long long, unsigned long long> seeds) // override
    {
        return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
                                          static_cast<const InputDataType*>(p_b),
                                          static_cast<ZDataType*>(p_z),
                                          static_cast<const InputDataType*>(p_b1),
                                          static_cast<const InputDataType*>(p_c),
                                          static_cast<const LSEDataType*>(p_lse),
                                          static_cast<const InputDataType*>(p_ygrad_grid),
                                          static_cast<OutputDataType*>(p_qgrad_grid),
                                          static_cast<OutputDataType*>(p_kgrad_grid),
                                          static_cast<OutputDataType*>(p_vgrad_grid),
                                          static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
                                          static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
                                          static_cast<D0DataType*>(p_d0grad_grid),
                                          static_cast<D1DataType*>(p_d1grad_grid),
                                          a_gs_ms_ks_lengths,
                                          a_gs_ms_ks_strides,
                                          b_gs_ns_ks_lengths,
                                          b_gs_ns_ks_strides,
                                          z_gs_ms_ns_lengths,
                                          z_gs_ms_ns_strides,
                                          b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                                          b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                                          c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                                          c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                                          lse_gs_ms_lengths,
                                          bgrad_gs_ns_ks_lengths,
                                          bgrad_gs_ns_ks_strides,
                                          b1grad_gs_gemm1ns_gemm1ks_lengths,
                                          b1grad_gs_gemm1ns_gemm1ks_strides,
                                          acc0_bias_gs_ms_ns_lengths,
                                          acc0_bias_gs_ms_ns_strides,
                                          acc1_bias_gs_ms_gemm1ns_lengths,
                                          acc1_bias_gs_ms_gemm1ns_strides,
                                          a_element_op,
                                          b_element_op,
                                          acc_element_op,
                                          b1_element_op,
                                          c_element_op,
                                          p_drop,
                                          seeds);
    }
    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V3"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerBlock << ", "
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
            << Gemm2KPerBlock << ", "
            << B1K1 << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
            << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
            << getMaskingSpecializationString(MaskingSpec) << ">";
        // clang-format on

        return str.str();
    }
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v3.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {

template <typename InputDataType,
          typename D0DataType,
          typename OutputDataType,
          typename ZDataType,
          typename GemmDataType,
          typename FloatGemmAcc,
          typename FloatCShuffle,
          typename FloatLSE,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename SElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename QGridDesc_K0_M_K1,
          typename KGridDesc_K0_N_K1,
          typename KGridDesc_N_K,
          typename D0GridDesc_M_N,
          typename ZGridDesc_M_N,
          typename VGridDesc_N0_O_N1,
          typename YGridDesc_M_O,
          typename LSEGridDesc_M,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t Gemm1NPerBlock,
          index_t Gemm1KPerBlock,
          index_t Gemm2KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t B1K1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          index_t Gemm1NXdlPerWave,
          index_t Gemm2NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
          index_t BBlockLdsExtraN,
          index_t D0BlockTransferSrcScalarPerVector,
          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          index_t B1BlockTransferSrcVectorDim,
          index_t B1BlockTransferSrcScalarPerVector,
          index_t B1BlockTransferDstScalarPerVector_BK1,
          bool B1ThreadTransferSrcResetCoordinateAfterRun,
          index_t B1BlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched,
          bool PadN,
          bool MaskOutUpperTriangle,
          bool Deterministic,
          PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V3
{
    static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
    static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
    static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
    static_assert(Gemm1NPerBlock % KPerBlock == 0);
    static_assert(MPerBlock % Gemm1KPerBlock == 0);
    static_assert(NPerBlock % Gemm2KPerBlock == 0);
    static_assert(LoopSched == LoopScheduler::Default,
                  "Non-default loop scheduler is currently not supported");

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr auto WaveSize = 64;

    // K1 should be Number<...>
    // Gemm0
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};

    static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
    static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);

    // Gemm1
    static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
    static constexpr auto B1K1 = Number<B1K1Value>{};

    static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;

    static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
    // get_random_16x8() generates 16 random numbers each time
    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
    // C desc for source in blockwise copy
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const ZGridDesc_M_N& z_grid_desc_m_n)
    {
        const auto M = z_grid_desc_m_n.GetLength(I0);
        const auto N = z_grid_desc_m_n.GetLength(I1);

        constexpr auto M3 = mfma.num_groups_per_blk;
        constexpr auto M4 = mfma.num_input_blks;
        constexpr auto M5 = mfma.group_size;

        return transform_tensor_descriptor(
            z_grid_desc_m_n,
            make_tuple(make_unmerge_transform(
                           make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
                       make_unmerge_transform(
                           make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
    }
    __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
    {
        return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
    }
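    // Worked example (illustration only, values assumed): with DropoutNThread = 2 the dropout
    // tile is 32, so GetPaddedSize(100) = integer_divide_ceil(100, 32) * 32 = 4 * 32 = 128;
    // sizes are rounded up to the next multiple of the RNG tile so every get_random_16x8()
    // call maps onto a full tile.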
    __device__ static auto GetGemm0WaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
    __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
    {
        constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(AK0, Number<MPerBlock>{}, AK1),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
    }
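    // LDS layout note with a worked example (assumed values, for illustration): with
    // MPerBlock = 128, ABlockLdsExtraM = 8 and AK1 = 8, one K0 slice spans
    // (128 + 8) * 8 elements, i.e. each K0 row is padded by ABlockLdsExtraM * AK1 elements;
    // the pad staggers LDS bank accesses and is never read back.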
    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }
    template <typename AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2>
    __host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
        const AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2& acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2)
    {
        // acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 to a_src_thread_desc_k0_m_k1
        // m0_m1_m2_m3 -> k0
        // n0_n1_n2 -> m
        // m4 -> k1
        // NOTE: had to use merge_v3 or will spit out compilation errors
        const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
        const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
        const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
        const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
        const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
        const auto m3 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
        const auto m4 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
        const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

        return transform_tensor_descriptor(
            acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2, m3)),
                       make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2)),
                       make_pass_through_transform(m4)),
            make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
    }
    __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B1 matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1K1),
            make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
    }
    __host__ __device__ static constexpr auto
    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = Gemm0NWaves;
        constexpr index_t NWave = Gemm0MWaves;

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
    template <typename Gemm2Param>
    __host__ __device__ static constexpr auto GetA2BlockDescriptor_K0_M_K1()
    {
        return make_naive_tensor_descriptor(
            make_tuple(Number<Gemm2Param::A_K0>{},
                       Number<Gemm2Param::Gemm2_M>{},
                       Number<Gemm2Param::A_K1>{}),
            make_tuple(Number<Gemm2Param::Gemm2_M + Gemm2Param::A_LdsPad>{} *
                           Number<Gemm2Param::A_K1>{},
                       Number<Gemm2Param::A_K1>{},
                       I1));
    }

    template <typename Gemm2Param>
    __host__ __device__ static constexpr auto GetB2BlockDescriptor_K0_N_K1()
    {
        return make_naive_tensor_descriptor(
            make_tuple(Number<Gemm2Param::B_K0>{},
                       Number<Gemm2Param::Gemm2_N>{},
                       Number<Gemm2Param::B_K1>{}),
            make_tuple(Number<Gemm2Param::Gemm2_N + Gemm2Param::B_LdsPad>{} *
                           Number<Gemm2Param::B_K1>{},
                       Number<Gemm2Param::B_K1>{},
                       I1));
    }
    __host__ __device__ static constexpr bool
    CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
                  const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
                  const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
                  const YGridDesc_M_O& y_grid_desc_m_o)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
        const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
        const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
        const auto O = v_grid_desc_n0_o_n1.GetLength(I1);

        // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into
        // 3 types of GEMM operations, therefore some code body can be reused accordingly
        // P_MNK / dP_MNO Gemm (Gemm0 rcr)
        // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
        // dV_NOM / dK_NKM Gemm (Gemm2 crr)
        if(O != K)
        {
            std::cerr << "O = " << O << " K = " << K << std::endl;
            std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
            return false;
        }

        if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
        {
            std::cerr << "M = " << M << " O = " << O
                      << " y_grid_desc_m_o = " << y_grid_desc_m_o.GetLength(I0) << " , "
                      << y_grid_desc_m_o.GetLength(I1) << std::endl;
            std::cerr << "Un-matched sizes!" << std::endl;
            return false;
        }

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
             O % Gemm1NPerBlock == 0))
        {
            std::cerr << "M = " << M << " N = " << N << " O = " << O << std::endl;
            std::cerr << "MPerBlock = " << MPerBlock << " NPerBlock = " << NPerBlock
                      << " KPerBlock = " << KPerBlock << std::endl;
            std::cerr << "Un-aligned sizes!" << std::endl;
            return false;
        }

        // check gemm0 gridwise gemm pipeline
        const auto num_gemm0_k_loop = K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
        {
            return false;
        }

        // check gemm1 gridwise gemm pipeline
        if(!(NPerBlock % Gemm1KPerBlock == 0))
        {
            return false;
        }

        const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }
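    // Shape example (assumed values, for illustration): with MPerBlock = 64, NPerBlock = 128,
    // KPerBlock = 64 and Gemm1NPerBlock = 256, a problem with M = N = 512 and K = O = 256
    // passes every divisibility check above, whereas K = 128 with O = 256 is rejected because
    // this backward kernel assumes equal Q/K head size (K) and V head size (O).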
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
    __host__ __device__ static constexpr auto
    MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(const YGridDesc_M_O& y_grid_desc_m_o)
    {
        const auto M = y_grid_desc_m_o.GetLength(I0);
        const auto O = y_grid_desc_m_o.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto OBlock = O / Gemm1NPerBlock;

        const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
            y_grid_desc_m_o,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return y_grid_desc_mblock_mperblock_oblock_operblock;
    }
    template <typename SrcBlockwiseGemm>
    __host__ __device__ static constexpr auto
    MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4(const LSEGridDesc_M& lse_grid_desc_m)
    {
        const index_t M      = lse_grid_desc_m.GetLength(I0);
        const index_t MBlock = M / MPerBlock;

        constexpr auto SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
            SrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        // M0 MXdlPerWave, M1 MWave, M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
        const auto M0 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I0);
        const auto M1 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I2);
        const auto M2 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I4);
        const auto M3 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I5);
        const auto M4 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I6);

        const auto lse_grid_desc_mb_m0_m1_m2_m3_m4 = transform_tensor_descriptor(
            lse_grid_desc_m,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, M0, M1, M2, M3, M4))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}));

        return lse_grid_desc_mb_m0_m1_m2_m3_m4;
    }
    __device__ static auto MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
        const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1)
    {
        const auto O0 = k_grid_desc_k0_n_k1.GetLength(I0);
        const auto N  = k_grid_desc_k0_n_k1.GetLength(I1);
        const auto O1 = k_grid_desc_k0_n_k1.GetLength(I2);
        const auto O  = O0 * O1;

        const auto NBlock = N / NPerBlock;
        const auto OBlock = O / Gemm1NPerBlock;

        const auto k_grid_desc_n_o = transform_tensor_descriptor(
            k_grid_desc_k0_n_k1,
            make_tuple(make_pass_through_transform(N),
                       make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return transform_tensor_descriptor(
            k_grid_desc_n_o,
            make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
                       make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
    }
    __device__ static auto MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
        const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1)
    {
        const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
        const auto O  = v_grid_desc_n0_o_n1.GetLength(I1);
        const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
        const auto N  = N0 * N1;

        const auto NBlock = N / NPerBlock;
        const auto OBlock = O / Gemm1NPerBlock;

        const auto v_grid_desc_n_o = transform_tensor_descriptor(
            v_grid_desc_n0_o_n1,
            make_tuple(make_pass_through_transform(O),
                       make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}));

        return transform_tensor_descriptor(
            v_grid_desc_n_o,
            make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
                       make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
    }
    __device__ static auto MakeQGradGridDesc_M_K(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1)
    {
        const auto K_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
        const auto M    = q_grid_desc_k0_m_k1.GetLength(I1);
        const auto K_K1 = q_grid_desc_k0_m_k1.GetLength(I2);

        return transform_tensor_descriptor(
            q_grid_desc_k0_m_k1,
            make_tuple(make_pass_through_transform(M),
                       make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const KGridDesc_N_K& k_grid_desc_n_k)
    {
        return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, Gemm1NPerBlock, KGridDesc_N_K>(
            k_grid_desc_n_k);
    }
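    // Note (illustration only, values assumed): with N = 512 and NPerBlock = 128 the map
    // distributes 512 / 128 = 4 row tiles of the K/V grid across workgroups; each workgroup
    // then iterates over Q inside the kernel (the "Qloop"), which is why the tile map is
    // built from the K grid descriptor rather than from Q.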
    using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
        MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(YGridDesc_M_O{}))>;

    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;

    using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
        MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
    // S / dP Gemm (type 1 rcc)
    struct Gemm0
    {
        // A matrix in LDS memory, dst of blockwise copy
        static constexpr auto a_block_desc_ak0_m_ak1 =
            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        template <typename ABlockDesc_AK0_M_AK1>
        __host__ __device__ static constexpr auto
        MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
        {
            constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
                ABlockDesc_AK0_M_AK1{});
        }

        template <typename BBlockDesc_BK0_N_BK1>
        __host__ __device__ static constexpr auto
        MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
        {
            constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
                BBlockDesc_BK0_N_BK1{});
        }

        template <typename GridDesc_K0_M_K1>
        using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            AElementwiseOperation,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            InputDataType,
            GemmDataType,
            GridDesc_K0_M_K1,
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            1,
            1,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>;

        template <typename GridDesc_K0_N_K1>
        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            BElementwiseOperation,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            InputDataType,
            GemmDataType,
            GridDesc_K0_N_K1,
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            1,
            1,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>;

        static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);

        // Blockwise gemm with transposed XDL output
        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
            BlockSize,
            GemmDataType,
            FloatGemmAcc,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MPerXdl,
            NPerXdl,
            MXdlPerWave,
            NXdlPerWave,
            KPack>;

        static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
    };
    // dV / dK Gemm (type 2 rrr)
    template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2,
              typename ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2>
    struct Gemm1
    {
        private:
        static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I0);
        static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I1);
        static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I2);
        static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I3);
        static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I4);
        static constexpr auto m3 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
        static constexpr auto m4 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I6);
        static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I7);

        // M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
        static constexpr auto M3 = ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);

        public:
        static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / m4 / M3>{};
        static constexpr auto AThreadSliceLength_M  = Number<n0 * n1 * n2>{};
        static constexpr auto AThreadSliceLength_K1 = Number<m4>{};

        // A source matrix layout in AccVGPR
        static constexpr auto a_src_thread_desc_k0_m_k1 =
            GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
                ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{});

        // A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
        static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
            make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));

        // B matrix in LDS memory, dst of blockwise copy
        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        template <typename ABlockDesc_AK0_M_AK1>
        __host__ __device__ static constexpr auto
        MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
        {
            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
                ABlockDesc_AK0_M_AK1{});
        }

        template <typename BBlockDesc_BK0_N_BK1>
        __host__ __device__ static constexpr auto
        MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
        {
            constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);

            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
                BBlockDesc_BK0_N_BK1{});
        }

        static constexpr auto ASrcScalarPerVector = m4;

        using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());

        template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
        using ABlockwiseCopy =
            ThreadwiseTensorSliceTransfer_StaticToStatic<FloatGemmAcc,
                                                         GemmDataType,
                                                         decltype(a_src_thread_desc_k0_m_k1),
                                                         decltype(a_thread_desc_k0_m_k1),
                                                         ElementwiseOp,
                                                         AThreadSliceLengths_K0_M_K1,
                                                         Sequence<1, 0, 2>,
                                                         2,
                                                         ASrcScalarPerVector>;

        template <typename GridDesc_K0_N_K1>
        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            BElementwiseOperation,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<B1K0, Gemm1NPerBlock, B1K1>,
            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
            B1BlockTransferThreadClusterArrangeOrder,
            InputDataType,
            GemmDataType,
            GridDesc_K0_N_K1,
            decltype(b_block_desc_bk0_n_bk1),
            B1BlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            B1BlockTransferSrcVectorDim,
            2,
            B1BlockTransferSrcScalarPerVector,
            B1BlockTransferDstScalarPerVector_BK1,
            1,
            1,
            B1ThreadTransferSrcResetCoordinateAfterRun,
            true, // DstResetCoord
            NumGemmKPrefetchStage>;

        // for a_block_slice_copy_step to be able to address static buffers, it MUST be a
        // tuple-based container as well as containing ONLY integral constants
        static constexpr auto a_block_slice_copy_step =
            make_tuple(AThreadSliceLength_K0, I0, I0);
        static constexpr auto b_block_slice_copy_step =
            make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);

        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
        // selected_mfma.k_per_blk <= Gemm1KPack
        //
        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
        // therefore we may just as well assign Gemm1KPack = group_size
        static constexpr index_t GemmKPack = mfma.group_size;
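        // Numeric sketch of the rationale above (assumed mfma parameters, for illustration):
        // for a 32x32x8 fp16 xdlops instruction, group_size is 4, so GemmKPack = 4; each
        // thread holds acc elements in contiguous runs of 4, and any larger KPack would have
        // to stitch together non-contiguous runs such as a1[0:3] and a1[8:11].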
        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
            BlockSize,
            GemmDataType,
            FloatGemmAcc,
            decltype(a_thread_desc_k0_m_k1),
            decltype(b_block_desc_bk0_n_bk1),
            decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
            decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
            NPerBlock,
            Gemm1NPerBlock,
            Gemm1KPerBlock,
            MPerXdl,
            NPerXdl,
            NXdlPerWave,
            Gemm1NXdlPerWave,
            GemmKPack,
            true,      // TransposeC
            GemmKPack, // AMmaKStride
            GemmKPack* XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
                .K0PerXdlops /* BMmaKStride */>;
    };
    // dQ Gemm (type 3 crr)
    // Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
    struct Gemm2Params
    {
        static constexpr index_t Gemm2_M = MPerBlock;      // 64
        static constexpr index_t Gemm2_K = NPerBlock;      // 128
        static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128
        static constexpr index_t Sum_K   = Gemm2KPerBlock;

        static constexpr index_t A_K1     = 8; // dS will be row-major
        static constexpr index_t A_K0     = Sum_K / A_K1;
        static constexpr index_t A_LdsPad = 0; // how many multiples of K1 per M * K1 elements

        static constexpr index_t B_K1     = B1K1; // dY assumed row-major, typically =2 for fp16
        static constexpr index_t B_K0     = Sum_K / B_K1;
        static constexpr index_t B_LdsPad = 0; // how many multiples of K1 per N * K1 elements

        static_assert(Sum_K % NPerXdl == 0, "");

        static constexpr index_t BSrcVectorDim       = 1; // Gemm2_N dimension
        static constexpr index_t BSrcScalarPerVector = 4;

        static constexpr index_t GemmNWave   = Gemm2_N / Gemm2NXdlPerWave / NPerXdl;    // 1 // 2
        static constexpr index_t GemmMWave   = BlockSize / get_warp_size() / GemmNWave; // 4 // 2
        static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;                        // 1 // 1
        static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;           // 1 // 1
        static constexpr index_t GemmKPack = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk);

        using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
        using BThreadClusterLengths =
            Sequence<BlockSize / (Gemm2_N / BSrcScalarPerVector), Gemm2_N / BSrcScalarPerVector, 1>;
        using BThreadClusterArrangeOrder = Sequence<0, 2, 1>;

        __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
        {
            // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
            constexpr index_t k  = Sum_K - 1;
            constexpr index_t k2 = k % NPerXdl;
            constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
            constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;

            // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
            constexpr index_t m  = Gemm2_M - 1;
            constexpr index_t m2 = m % MPerXdl;
            constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
            constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;

            // assume 256 decomposed into 2 x 4 x 32
            // 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
            // 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
            return Sequence<m0, k0, m1, k1, m2, k2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
        }

        __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1()
        {
            return generate_sequence_v2(
                [](auto I) { return GetABlockSliceLengths_M0_K0_M1_K1_M2_K2().At(I); },
                Number<4>{});
        }

        using ABlockSliceLengths_M0_K0_M1_K1 = decltype(GetABlockSliceLengths_M0_K0_M1_K1());
        //(2, 1, 1, 2) //(4, 1, 1, 2)
    };
    // dQ Gemm (type 3 crr)
    template <typename Gemm2Params, typename ASrcBlockwiseGemm>
    struct Gemm2
    {
        private:
        static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); // repeat
        static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
        static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); // wave
        static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
        static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); // xdl
        static constexpr auto M3 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
        static constexpr auto M4 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
        static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

        public:
        // A source matrix layout in VGPR, src of VGPR-to-LDS copy
        static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        // A matrix in LDS memory, dst of blockwise copy
        static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

        // B matrix in LDS memory, dst of blockwise copy
        static constexpr auto b_block_desc_k0_n_k1 = GetB2BlockDescriptor_K0_N_K1<Gemm2Params>();

        __host__ __device__ static constexpr auto MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2()
        {
            const auto K0_ = a_block_desc_k0_m_k1.GetLength(I0);
            const auto M_  = a_block_desc_k0_m_k1.GetLength(I1);
            const auto K1_ = a_block_desc_k0_m_k1.GetLength(I2);

            const auto a_block_desc_m_k = transform_tensor_descriptor(
                a_block_desc_k0_m_k1,
                make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0_, K1_)),
                           make_pass_through_transform(M_)),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

            // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
            // variable I1 there
            return transform_tensor_descriptor(
                a_block_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2, M3, M4)),
                           make_unmerge_transform(make_tuple(I1, N1, N2))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
        }

        // Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
        // destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
        // operation before returning the values
        __host__ __device__ static auto MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2()
        {
            const auto a_thread_origin_on_block_idx =
                ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);

            constexpr auto a_block_slice_lengths_m0_k0_m1_k1 =
                typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1{};

            // mrepeat, nrepeat,
            // mwaves, nwaves,
            return make_tuple(
                a_thread_origin_on_block_idx[I0], // mrepeat
                a_thread_origin_on_block_idx[I1], // nrepeat
                a_thread_origin_on_block_idx[I2] % a_block_slice_lengths_m0_k0_m1_k1[I2], // mwave
                a_thread_origin_on_block_idx[I3] % a_block_slice_lengths_m0_k0_m1_k1[I3], // nwave
                a_thread_origin_on_block_idx[I4], // xdlops
                a_thread_origin_on_block_idx[I5],
                a_thread_origin_on_block_idx[I6],
                a_thread_origin_on_block_idx[I7]);
        }

        static constexpr auto a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2 =
            MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2();

        using ASrcBlockSliceWindowIterator =
            SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
                              Sequence<0, 1, 2, 3>,
                              typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1,
                              false>;

        template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
        using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
            FloatGemmAcc,
            GemmDataType,
            decltype(a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2),
            ElementwiseOp,
            Sequence<Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I0), // ThreadSliceLengths
                     Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I1),
                     I1,
                     I1,
                     M2,
                     I1,
                     M4,
                     I1>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
            7, // DstVectorDim
            1, // DstScalarPerVector
            InMemoryDataOperationEnum::Set,
            1, // DstScalarStrideInVector
            true>;

        template <typename GridDesc_K0_N_K1>
        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            typename Gemm2Params::BBlockSliceLengths,
            typename Gemm2Params::BThreadClusterLengths,
            typename Gemm2Params::BThreadClusterArrangeOrder,
            InputDataType,
            GemmDataType,
            GridDesc_K0_N_K1,
            decltype(b_block_desc_k0_n_k1),
            typename Gemm2Params::BThreadClusterArrangeOrder, // access order == thread order
            Sequence<1, 0, 2>,
            Gemm2Params::BSrcVectorDim,
            2, // DstVectorDim
            Gemm2Params::BSrcScalarPerVector,
            Gemm2Params::B_K1,
            1,
            1,
            true,
            true,
            1>;

        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                GemmDataType,
                                                                FloatGemmAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXdl,
                                                                NPerXdl,
                                                                Gemm2Params::GemmMRepeat,
                                                                Gemm2Params::GemmNRepeat,
                                                                Gemm2Params::GemmKPack,
                                                                true>; // TransposeC

        static constexpr auto b_block_slice_copy_step = make_multi_index(Gemm2Params::B_K0, 0, 0);
        static constexpr auto c_block_slice_copy_step =
            make_multi_index(-Gemm2Params::GemmMRepeat, 0, 0, 0, 0, 0, 0, 0);
        static constexpr auto b_block_reset_copy_step =
            make_multi_index(-NPerBlock / Gemm2Params::B_K1, 0, 0);

        template <typename CGradDesc_M_N>
        __host__ __device__ static auto
        MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(const CGradDesc_M_N& c_grid_desc_m_n)
        {
            // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
            // variable I1 there
            const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(
                    make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmMWave, MPerXdl)),
                    make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmNWave, NPerXdl))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

            const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
                BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
                    c_grid_desc_m0_n0_m1_n1_m2_n2);

            return c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4;
        }

        static constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        __host__ __device__ static auto GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
        {
            return to_multi_index(
                BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
        }

        template <typename CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
                  typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
        using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
            FloatGemmAcc,
            OutputDataType,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
            CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
            ElementwiseOp, // CElementwiseOperation
            decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths()), // SliceLengths
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,     // AccessOrder
            7,                                    // VectorDim
            2,                                    // ScalarPerVector
            InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
            1,                                    // DstScalarStrideInVector
            true>;
    };
    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
    {
        static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
        static constexpr auto ThreadClusterLength_O =
            Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
        static constexpr auto ThreadClusterLength_M =
            Number<BlockSize_ / ThreadClusterLength_O>{};
        static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
        static constexpr auto ThreadSliceLength_M =
            Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};

        static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
        static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");

        using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                        FloatGemmAcc,
                                        ThreadSliceLength_M * ThreadSliceLength_O,
                                        true>;

        using DstBufType =
            StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
    };
    using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
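    // Worked thread-mapping example (assumed values, for illustration): for fp16 input,
    // SrcScalarPerVector = 16 / 2 = 8. With Gemm1NPerBlock = 128 and BlockSize = 256 this
    // gives ThreadClusterLength_O = 128 / 8 = 16 lanes along O and ThreadClusterLength_M =
    // 256 / 16 = 16 along M, so each thread reduces 8-wide strips of y * dy into one partial
    // sum per M row it owns.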
    // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
    struct PGradGemmTile_M_N_O
    {
        // TODO:
        // Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
        // things more concise
        template <typename YGradGridDesc_M0_O_M1_>
        __device__ static auto
        MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
        {
            const auto M0 = ygrad_grid_desc_m0_o_m1.GetLength(I0);
            const auto O  = ygrad_grid_desc_m0_o_m1.GetLength(I1);
            const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2);

            constexpr auto Y_O1 = AK1;
            const auto Y_O0     = O / Y_O1;

            const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor(
                ygrad_grid_desc_m0_o_m1,
                make_tuple(make_unmerge_transform(make_tuple(Y_O0, Y_O1)),
                           make_merge_transform_v3_division_mod(make_tuple(M0, M1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ygrad_grid_desc_o0_m_o1;
        }

        template <typename VGridDesc_N0_O_N1_>
        __device__ static auto
        MakeVGridDesc_O0_N_O1(const VGridDesc_N0_O_N1_& v_grid_desc_n0_o_n1)
        {
            const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
            const auto O  = v_grid_desc_n0_o_n1.GetLength(I1);
            const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);

            constexpr auto V_O1 = BK1;
            const auto V_O0     = O / V_O1;

            const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
                v_grid_desc_n0_o_n1,
                make_tuple(make_unmerge_transform(make_tuple(V_O0, V_O1)),
                           make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return v_grid_desc_o0_n_o1;
        }
    };
    // QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
    struct QGradGemmTile_M_K_N
    {
        template <typename KGridDesc_K0_N_K1_>
        __device__ static auto
        MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
        {
            const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
            const auto N    = k_grid_desc_k0_n_k1.GetLength(I1);
            const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);

            constexpr auto K_N1 = B1K1;
            const auto K_N0     = N / K_N1;

            const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor(
                k_grid_desc_k0_n_k1,
                make_tuple(make_unmerge_transform(make_tuple(K_N0, K_N1)),
                           make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return k_grid_desc_n0_k_n1;
        }
    };
    struct KGradGemmTile_N_K_M
    {
        // B position
        template <typename QGridDesc_K0_M_K1_>
        __device__ static auto
        MakeQGridDesc_M0_K_M1(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
        {
            const auto Q_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
            const auto M    = q_grid_desc_k0_m_k1.GetLength(I1);
            const auto Q_K1 = q_grid_desc_k0_m_k1.GetLength(I2);

            constexpr auto Q_M1 = B1K1;
            const auto Q_M0     = M / Q_M1;

            const auto q_grid_desc_m0_k_m1 = transform_tensor_descriptor(
                q_grid_desc_k0_m_k1,
                make_tuple(make_unmerge_transform(make_tuple(Q_M0, Q_M1)),
                           make_merge_transform_v3_division_mod(make_tuple(Q_K0, Q_K1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return q_grid_desc_m0_k_m1;
        }
    };
    // D0
    static constexpr auto D0M2 = Number<4>{};
    static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
    static constexpr auto D0M0 = Number<MPerBlock>{} / Number<MPerXdl>{};
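    // Decomposition example (assumed values, for illustration): with MPerXdl = 32 and
    // MPerBlock = 128, the D0 bias tile along M splits into D0M0 = 128 / 32 = 4 xdl rows,
    // D0M1 = 32 / 4 = 8 groups, and D0M2 = 4 contiguous elements per group.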
    __host__ __device__ static constexpr auto
    MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
    {
        const auto M = d0_grid_desc_m_n.GetLength(I0);
        const auto N = d0_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
            d0_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1, D0M2)),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));

        return d0_grid_desc_m0_n0_m1_m2_n1_m3;
    }

    using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
    struct D0Operator
    {
        template <typename DataType>
        struct TypeTransform
        {
            using Type = DataType;
            static constexpr index_t Size0 = sizeof(DataType);
            static constexpr index_t Size  = sizeof(DataType);
        };

        template <>
        struct TypeTransform<void>
        {
            using Type = ck::half_t;
            static constexpr index_t Size0 = 0;
            static constexpr index_t Size  = sizeof(ck::half_t);
        };

        static constexpr index_t NThreadClusterLengths = 32;
        static_assert(NPerXdl == 32);
        static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
                      "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");

        __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
        {
            return make_naive_tensor_descriptor_packed(
                make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
        }

        __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
        {
            constexpr auto d0_raw_m0_n_m1 =
                make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));

            constexpr auto d0_n0_n1_m0_m1_m2 = transform_tensor_descriptor(
                d0_raw_m0_n_m1,
                make_tuple(make_unmerge_transform(make_tuple(D0M1 / I2, I2)),
                           make_unmerge_transform(
                               make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
                           make_pass_through_transform(D0M2)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));

            return d0_n0_n1_m0_m1_m2;
        }

        static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
            GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
        static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
            GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();

        static constexpr auto d0_thread_desc_ =
            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));

        static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
            d0_block_src_desc_n0_n1_m0_m1_m2;
        static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
            d0_block_dst_desc_m0_n0_m1_m2_n1_m3;

        using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
            Sequence<1, 1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
            Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
            typename TypeTransform<D0DataType>::Type, // SrcData
            typename TypeTransform<D0DataType>::Type, // DstData
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,       // SrcDesc
            decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
            Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
            Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
            4, // SrcVectorDim
            5, // DstVectorDim
            4, // SrcScalarPerVector
            4, // DstScalarPerVector
            1,
            1,
            true,
            true, // DstResetCoord
            1>;

        using D0ThreadwiseCopyLdsToVgpr = ThreadwiseTensorSliceTransfer_v4<
            typename TypeTransform<D0DataType>::Type,   // SrcData
            typename TypeTransform<D0DataType>::Type,   // DstData
            decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
            decltype(d0_thread_desc_),                  // DstDesc
            Sequence<1, 1, 4, 1, 4>, // SliceLengths
            Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
            4, // SrcVectorDim
            4, // SrcScalarPerVector
            2>;

        using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
            FloatGemmAcc,
            typename TypeTransform<D0DataType>::Type,
            decltype(d0_thread_desc_),
            decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
            tensor_operation::element_wise::Scale, // CElementwiseOperation
            Sequence<1, 1, 4, 1, 4>,        // SliceLengths
            Sequence<0, 1, 2, 3, 4>,        // AccessOrder
            4,                              // VectorDim
            4,                              // ScalarPerVector
            InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
            1,                              // DstScalarStrideInVector
            true>;

        using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
            Sequence<1, 1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
            Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
            typename TypeTransform<D0DataType>::Type, // SrcData
            typename TypeTransform<D0DataType>::Type, // DstData
            decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,                // DstDesc
            Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
            Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
            5, // SrcVectorDim
            4, // DstVectorDim
            4, // SrcScalarPerVector
            D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
            1,
            1,
            true,
            true, // DstResetCoord
            1>;
    };
    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
        static constexpr auto a_block_desc_ak0_m_ak1 =
            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto b1_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
        static constexpr auto b2_block_desc_k0_n_k1 = GetB2BlockDescriptor_K0_N_K1<Gemm2Params>();

        static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
            a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b2_block_space_size_aligned = math::integer_least_multiple(
            b2_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset  = 0;
        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset = 0;
        static constexpr auto a2_block_space_offset = 0;
        static constexpr auto b2_block_space_offset = a2_block_space_size_aligned.value;

        // LDS allocation for reduction
        static constexpr index_t reduction_space_size_aligned =
            math::integer_least_multiple(BlockSize, max_lds_align);

        static constexpr auto reduction_space_offset = 0;

        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
            D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
        static constexpr auto d0_block_space_offset = 0;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        static constexpr auto c_block_space_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
    };
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                         SharedMemTrait::b_block_space_size_aligned) *
                                        sizeof(GemmDataType);
        const index_t gemm1_bytes_end = (SharedMemTrait::b1_block_space_offset +
                                         SharedMemTrait::b1_block_space_size_aligned) *
                                        sizeof(GemmDataType);
        const index_t gemm2_bytes_end = (SharedMemTrait::a2_block_space_size_aligned +
                                         SharedMemTrait::b2_block_space_size_aligned) *
                                        sizeof(GemmDataType);
        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                           SharedMemTrait::reduction_space_size_aligned) *
                                          sizeof(FloatGemmAcc);
        const index_t d0_bytes_end = (SharedMemTrait::d0_block_space_offset +
                                      SharedMemTrait::d0_block_space_size_aligned) *
                                     D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

        return math::max(gemm0_bytes_end,
                         gemm1_bytes_end,
                         gemm2_bytes_end,
                         softmax_bytes_end,
                         d0_bytes_end,
                         c_block_bytes_end);
    }
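    // Sizing note (a reading of the function above; the byte figures are
    // hypothetical, for illustration only): every phase's region starts inside
    // the same p_shared allocation, and several regions share offset 0
    // (b1, a2, reduction, d0, c-shuffle), so the regions alias each other.
    // If, say, the aligned Gemm0 A and B tiles took 4 KiB and 8 KiB of fp16,
    // gemm0_bytes_end would be 12 KiB, and the kernel's LDS request would be
    // the max over the six phase ends, not their sum, since the phases are
    // separated by barriers and never need to be live at the same time.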
    template <bool HasMainKBlockLoop,
              bool IsDropout,
              typename Block2CTileMap,
              typename C0MatrixMask,
              typename YGradGridDesc_M0_O_M1>
    __device__ static void
    Run(const InputDataType* __restrict__ p_q_grid,
        const InputDataType* __restrict__ p_k_grid,
        const D0DataType* __restrict__ p_d0_grid,
        ZDataType* __restrict__ p_z_grid,
        const InputDataType* __restrict__ p_v_grid,
        const InputDataType* __restrict__ p_y_grid,
        const FloatLSE* __restrict__ p_lse_grid,
        const InputDataType* __restrict__ p_ygrad_grid,
        OutputDataType* __restrict__ p_qgrad_grid,
        OutputDataType* __restrict__ p_kgrad_grid,
        D0DataType* __restrict__ p_d0grad_grid,
        OutputDataType* __restrict__ p_vgrad_grid,
        void* __restrict__ p_shared,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const SElementwiseOperation& s_element_op,
        const B1ElementwiseOperation& b1_element_op,
        const CElementwiseOperation& c_element_op,
        const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
        const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
        const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
        const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
        const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
        const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
        const VGridDesc_N0_O_N1& vgrad_grid_desc_n0_o_n1,
        const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& y_grid_desc_mblock_mperblock_oblock_operblock,
        const LSEGridDesc_M& lse_grid_desc_m,
        const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
        const Block2CTileMap& block_2_ctile_map,
        const C0MatrixMask& c0_matrix_mask,
        const float p_drop,
        ck::philox& ph,
        const index_t z_random_matrix_offset,
        const index_t raw_n_padded,
        const index_t block_idx_n)
    {
        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
        const uint8_t p_dropout_in_uint8_t =
            __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                     rp_dropout);
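        // Worked example (illustrative values, not taken from any caller):
        // with p_drop = 0.2,
        //   p_dropout            = 1.0f - 0.2f         = 0.8f
        //   rp_dropout           = 1.0f / 0.8f         = 1.25f
        //   p_dropout_in_uint8_t = floor(0.8f * 255.0) = 204
        // Kept elements are later scaled by rp_dropout so the dropped-out P has
        // the same expected value as P; the uint8 threshold is presumably what
        // the philox-generated random bytes in Z are compared against.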
        const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
        const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
        const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
        const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
        const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
        auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_vgrad_grid, vgrad_grid_desc_n0_o_n1.GetElementSpaceSize());
        auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
        auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());

        // divide block work by [N, K]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];

        // HACK: this forces n_block_data_idx_on_grid into an SGPR
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);

        const index_t num_gemm0_m_block_outer_loop =
            q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
        constexpr index_t num_gemm1_k_block_inner_loop = MPerBlock / Gemm1KPerBlock;

        // 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
        // S_MNK / dP_MNO Gemm (Gemm0 rcc)
        // dV_NOM / dK_NKM Gemm (Gemm1 rrr)
        // Y_MON / dQ_MKN Gemm (Gemm2 crr)
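        //
        // Dataflow sketch (summarizing the bucket list above and the comments at
        // each GEMM below; shapes are schematic, not literal code): per
        // [M-tile, N-tile] outer iteration the kernel computes
        //   S  = Q * K^T              (Gemm0)  then P = softmax(S) via stored LSE
        //   dP = dY * V^T             (Gemm0 layout, same pipeline reused)
        //   dS = P .* (dP - Y.dY)     (elementwise softmax backward)
        //   dV += P^T  * dY           (Gemm1)
        //   dK += dS^T * Q            (Gemm1)
        //   dQ += dS   * K            (Gemm2, accumulated atomically to global)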
//
        // set up S / dP Gemm (type 1 rcc)
        //

        // Gemm0: LDS allocation for A and B: be careful of alignment
        auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
            Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
            Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        // Gemm0: gridwise GEMM pipeline
        // Only supports LoopScheduler::Default
        const auto gemm0_gridwise_gemm_pipeline =
            GridwiseGemmPipeline_Selector<PipelineVer,
                                          NumGemmKPrefetchStage,
                                          LoopScheduler::Default>();

        // S: A matrix blockwise copy
        auto s_gemm_tile_q_blockwise_copy =
            typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
                q_grid_desc_k0_m_k1,
                make_multi_index(0,
                                 MPerBlock * (num_gemm0_m_block_outer_loop - 1),
                                 0), // will loop over GemmM dimension
                a_element_op,
                Gemm0::a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        // S: B matrix blockwise copy
        auto s_gemm_tile_k_blockwise_copy =
            typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
                k_grid_desc_k0_n_k1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                Gemm0::b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        // S: blockwise gemm
        auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC

        auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();

        const auto s_gemm_tile_q_block_reset_copy_step =
            make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), -MPerBlock, 0);
        const auto s_gemm_tile_k_block_reset_copy_step =
            make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), 0, 0);

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);

        // dP: transform input and output tensor descriptors
        const auto ygrad_grid_desc_o0_m_o1 =
            PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
        const auto v_grid_desc_o0_n_o1 =
            PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);

        // dP: A matrix blockwise copy
        auto pgrad_gemm_tile_ygrad_blockwise_copy =
            typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
                ygrad_grid_desc_o0_m_o1,
                make_multi_index(0,
                                 MPerBlock * (num_gemm0_m_block_outer_loop - 1),
                                 0), // will loop over GemmM dimension
                tensor_operation::element_wise::PassThrough{},
                Gemm0::a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        // dP: B matrix blockwise copy
        auto pgrad_gemm_tile_v_blockwise_copy =
            typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
                v_grid_desc_o0_n_o1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                tensor_operation::element_wise::PassThrough{},
                Gemm0::b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        // dP: blockwise gemm
        // we need a separate blockwise gemm object because we need a separate thread buffer
        auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};

        auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();

        const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
            make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), -MPerBlock, 0);
        const auto pgrad_gemm_tile_v_block_reset_copy_step =
            make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), 0, 0);

        const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
            (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
            KPerBlock);

        //
        // set up dV / dK Gemm (type 2 rrr)
        //
        using Gemm1 =
            Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()),
                  decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2())>;

        // Gemm1: VGPR allocation for A and LDS allocation for B
        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
            Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
        auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        // dV: A matrix blockwise copy
        auto vgrad_gemm_tile_p_blockwise_copy =
            typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
                tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
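        // Note on the Relu above (a reading of the surrounding code, not an extra
        // transform): dropout appears to mark dropped P elements by storing them
        // with a negative sign, so Relu zeroes them for the dV GEMM, and the later
        // `s_slash_p_thread_buf[i] >= 0` test in the dS computation recovers the
        // same keep/drop mask without a separate mask buffer.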
        // dV: B matrix blockwise copy
        auto vgrad_gemm_tile_ygrad_blockwise_copy =
            typename Gemm1::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
                ygrad_grid_desc_m0_o_m1,
                make_multi_index(MPerBlock / B1K1 * (num_gemm0_m_block_outer_loop - 1), 0, 0),
                b1_element_op,
                Gemm1::b_block_desc_bk0_n_bk1, // here n actually is k and k is N, so the name
                                               // could be b_block_desc_bn0_k_bn1
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        const auto vgrad_gemm_tile_ygrad_block_next_copy_step =
            make_multi_index(-2 * MPerBlock / B1K1, 0, 0);

        // dV: blockwise gemm
        auto vgrad_blockwise_gemm =
            typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin

        auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();

        // dV: transform input and output tensor descriptors
        auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
            MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_n0_o_n1);

        // dK: transform input and output tensor descriptors
        const auto q_grid_desc_m0_k_m1 =
            KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);

        // dK: A matrix blockwise copy
        auto kgrad_gemm_tile_sgrad_blockwise_copy =
            typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
                tensor_operation::element_wise::PassThrough{}};

        // dK: B matrix blockwise copy
        auto kgrad_gemm_tile_q_blockwise_copy =
            typename Gemm1::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
                q_grid_desc_m0_k_m1,
                make_multi_index(MPerBlock / B1K1 * (num_gemm0_m_block_outer_loop - 1), 0, 0),
                b1_element_op,
                Gemm1::b_block_desc_bk0_n_bk1, // here n actually is k and k is N, so the name
                                               // could be b_block_desc_bn0_k_bn1
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        const auto kgrad_gemm_tile_q_block_next_copy_step =
            make_multi_index(-2 * MPerBlock / B1K1, 0, 0);

        // dK: blockwise gemm
        auto kgrad_blockwise_gemm =
            typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin

        auto kgrad_thread_buf = kgrad_blockwise_gemm.GetCThreadBuffer();

        // dK: transform input and output tensor descriptors
        auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
            MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
        //
        // set up dQ Gemm (type 3 crr)
        //
        using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>;

        // Gemm2: LDS allocation for A and B: be careful of alignment
        auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
            Gemm2::a_block_desc_k0_m_k1.GetElementSpaceSize());
        auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
            Gemm2::b_block_desc_k0_n_k1.GetElementSpaceSize());

        // // dQ: transform input and output tensor descriptors
        // const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
        //     Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);

        // dQ: transform input and output tensor descriptors
        const auto k_grid_desc_n0_k_n1 =
            QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);

        // dQ: A matrix VGPR-to-LDS blockwise copy
        auto qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
            typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
                Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
                Gemm2::MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2(),
                tensor_operation::element_wise::PassThrough{}};

        // dQ: B matrix global-to-LDS blockwise copy
        auto qgrad_gemm_tile_k_blockwise_copy =
            typename Gemm2::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
                k_grid_desc_n0_k_n1,
                make_multi_index(n_block_data_idx_on_grid / Gemm2Params::B_K1, 0, 0),
                tensor_operation::element_wise::PassThrough{},
                Gemm2::b_block_desc_k0_n_k1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        // dQ: blockwise gemm
        auto qgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};

        auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();

        // dQ: transform output tensor descriptors
        const auto qgrad_grid_desc_m_k = MakeQGradGridDesc_M_K(q_grid_desc_k0_m_k1);
        const auto qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4 =
            Gemm2::MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(qgrad_grid_desc_m_k);

        // dQ: C VGPR-to-global copy
        const auto qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4 =
            Gemm2::GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4() +
            make_multi_index((num_gemm0_m_block_outer_loop - 1) * Gemm2Params::GemmMRepeat,
                             I0, I0, I0, I0, I0, I0, I0);

        auto qgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
            decltype(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4),
            decltype(scale_rp_dropout)>(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
                                        qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
                                        scale_rp_dropout);
        //
        // Blockwise softmax
        //

        // get acc0 8D thread cluster
        constexpr auto thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2 =
            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths() /
            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
        constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I0);
        constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I1);
        constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I2);
        constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I3);
        constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I4);
        constexpr auto tm3 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I5);
        constexpr auto tm4 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I6);
        constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I7);

        // get acc0 thread map
        constexpr auto n0_m_n1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(tn0 * tn1, tn2)),
                       make_pass_through_transform(I1)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        constexpr auto threadid_to_n0_m_n1_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(
                make_merge_transform(make_tuple(tn0 * tn1, tm0 * tm1 * tm2 * tm3 * tm4, tn2))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));
        const auto threadid_to_m_n_thread_cluster_adaptor =
            chain_tensor_adaptors(n0_m_n1_to_m_n_adaptor, threadid_to_n0_m_n1_adaptor);

        // get acc0 2D thread cluster & 2D thread slice
        constexpr auto thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
        constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
        constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
        constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
        constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
        constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
        constexpr auto m3 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
        constexpr auto m4 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
        constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

        constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
            make_tuple(tm0 * tm1 * tm2 * tm3 * tm4, tn0 * tn1 * tn2));
        constexpr auto thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
            make_tuple(m0 * m1 * m2 * m3 * m4, n0 * n1 * n2));

        auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
                                                  FloatGemmAcc,
                                                  decltype(threadid_to_m_n_thread_cluster_adaptor),
                                                  decltype(thread_cluster_desc_m_n),
                                                  decltype(thread_slice_desc_m_n)>{};

        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
            p_dropout_in_uint8_t, rp_dropout};

        auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
            MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
        constexpr auto lse_thread_desc_mb_m0_m1_m2_m3_m4 =
            make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2, m3, m4));

        auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
            lse_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());

        auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
            Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});

        auto lse_thread_copy_global_to_vgpr =
            ThreadwiseTensorSliceTransfer_v2<FloatLSE,
                                             FloatLSE,
                                             decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
                                             decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
                                             Sequence<1, m0, m1, m2, m3, m4>,
                                             Sequence<0, 1, 2, 3, 4, 5>,
                                             5,
                                             1,
                                             1,
                                             true /* ResetCoordAfterRun */>{
                lse_grid_desc_mb_m0_m1_m2_m3_m4,
                make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
                                 acc0_thread_origin[I0],           // mrepeat
                                 acc0_thread_origin[I2],           // mwave
                                 acc0_thread_origin[I4],           // mperxdl
                                 acc0_thread_origin[I5],
                                 acc0_thread_origin[I6])};
        //
        // z vgpr copy to global
        //
        // z matrix threadwise desc
        constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,   // MBlockId
                                                           I1,   // NBlockID
                                                           m0,   // MRepeat
                                                           n0,   // NRepeat
                                                           m1,   // MWaveId
                                                           n1,   // NWaveId
                                                           m2,   // MGroupNum
                                                           m3,   // MInputNum
                                                           m4,   // RegisterNum
                                                           n2)); // NPerXdl

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     uint8_t,
                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                     true>
            z_tensor_buffer;
        z_tensor_buffer.Clear();

        auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());

        const auto wave_id     = GetGemm0WaveIdx();
        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
            uint8_t,
            ZDataType,
            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
            tensor_operation::element_wise::PassThrough,
            Sequence<I1, // MBlockId
                     I1, // NBlockID
                     m0, // MRepeat
                     n0, // NRepeat
                     m1, // MWaveId
                     n1, // NWaveId
                     m2, // MPerXdl
                     m3, // NGroupNum
                     m4, // NInputNum
                     n2>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            9, // DstVectorDim
            1, // DstScalarPerVector
            InMemoryDataOperationEnum::Set,
            1, // DstScalarStrideInVector
            true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                  make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
                                   block_work_idx_n,                 // NBlockId
                                   0,                                // MRepeat
                                   0,                                // NRepeat
                                   wave_id[I0],                      // MWaveId
                                   wave_id[I1],                      // NWaveId
                                   0,                                // MPerXdl
                                   wave_m_n_id[I0],                  //
                                   0,                                //
                                   wave_m_n_id[I1]),                 // NPerXdl
                  tensor_operation::element_wise::PassThrough{}};
        //
        // set up Y dot dY
        //

        // m0, n0 are m/n repeat per wave
        // m1, n1 are number of waves
        constexpr auto p_block_lengths =
            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
        constexpr auto P_M0 = p_block_lengths[I0]; // repeats
        constexpr auto P_M1 = p_block_lengths[I2]; // waves
        constexpr auto P_M2 = p_block_lengths[I4]; // xdl
        constexpr auto P_M3 = p_block_lengths[I5];
        constexpr auto P_M4 = p_block_lengths[I6];

        constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(
            make_tuple(I1,
                       YDotYGrad_M_O::ThreadSliceLength_M,
                       I1,
                       YDotYGrad_M_O::ThreadSliceLength_O));
        constexpr auto y_thread_cluster_desc =
            make_cluster_descriptor(Sequence<I1,
                                             YDotYGrad_M_O::ThreadClusterLength_M,
                                             I1,
                                             YDotYGrad_M_O::ThreadClusterLength_O>{},
                                    Sequence<0, 1, 2, 3>{});
        const auto y_thread_cluster_idx =
            y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto y_thread_data_on_block_idx =
            y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
        const auto y_thread_data_on_grid_idx =
            make_multi_index(num_gemm0_m_block_outer_loop - 1, I0, I0, I0) +
            y_thread_data_on_block_idx;

        // performs double duty for both y and ygrad
        auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
            InputDataType,
            FloatGemmAcc,
            YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
            decltype(y_thread_desc_m0_m1_o0_o1),
            decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
            Sequence<0, 1, 2, 3>,
            3,                                 // SrcVectorDim
            YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
            1,                                 // SrcScalarStrideInVector
            true /* ResetCoordAfterRun */,
            false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
                                             y_thread_data_on_grid_idx);

        auto y_thread_buf                 = typename YDotYGrad_M_O::SrcBufType{};
        auto ygrad_thread_buf             = typename YDotYGrad_M_O::SrcBufType{};
        auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};

        auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatGemmAcc*>(p_shared), MPerBlock);

        constexpr auto y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4 =
            make_naive_tensor_descriptor_packed(make_tuple(I1, P_M0, P_M1, P_M2, P_M3, P_M4));
        constexpr auto y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4 =
            lse_thread_desc_mb_m0_m1_m2_m3_m4; // reuse LSE thread descriptor because
                                               // per-thread LSE data and y_dot_ygrad is
                                               // tiled the same way

        auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
            FloatGemmAcc,
            FloatGemmAcc,
            decltype(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4),
            decltype(y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4),
            Sequence<1, m0, m1, m2, m3, m4>,
            Sequence<0, 1, 2, 3, 4, 5>,
            5,
            1,
            1,
            true /* ResetCoordAfterRun */>{
            y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
            make_multi_index(I0,                     // mblock
                             acc0_thread_origin[I0], // mrepeat
                             acc0_thread_origin[I2], // mwave
                             acc0_thread_origin[I4], // mperxdl
                             acc0_thread_origin[I5],
                             acc0_thread_origin[I6])};

        auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
            y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
        // gemm0 M loop
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;

        // D0
        auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
            D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});

        auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));

        auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;

        auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
            D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
            tensor_operation::element_wise::Scale{rp_dropout});

        auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
            D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
            d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});

        if constexpr(Deterministic)
        {
            block_sync_lds();
        }
        // Initialize dK & dV
        kgrad_thread_buf.Clear();
        vgrad_thread_buf.Clear();

        do
        {
            auto m_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
            if(c0_matrix_mask.IsTileSkippable(
                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
            {
                continue;
            }
            //
            // calculate Y dot dY
            //

            // clear accum buffers
            y_dot_ygrad_thread_accum_buf.Clear();
            y_dot_ygrad_block_accum_buf.Clear();

            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
                                       y_grid_buf,
                                       y_thread_desc_m0_m1_o0_o1,
                                       make_tuple(I0, I0, I0, I0),
                                       y_thread_buf);
            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
                                       ygrad_grid_buf,
                                       y_thread_desc_m0_m1_o0_o1,
                                       make_tuple(I0, I0, I0, I0),
                                       ygrad_thread_buf);

            static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
                static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
                    constexpr auto offset = y_thread_desc_m0_m1_o0_o1.CalculateOffset(
                        make_multi_index(I0, iM, I0, iO));
                    y_dot_ygrad_thread_accum_buf(iM) +=
                        y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
                });
            });
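            // Reduction sketch (restating the loop above, no extra work): each
            // thread owns a ThreadSliceLength_M x ThreadSliceLength_O patch of Y
            // and dY, so after this loop
            //   y_dot_ygrad_thread_accum_buf(iM) = sum_over_iO Y[iM][iO] * dY[iM][iO]
            // is a partial row dot product; the atomic adds below combine the
            // partials from all threads covering the same row into a single LDS
            // value per row of the M-tile.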
            // blockwise reduction using atomic_add
            block_sync_lds();
            static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
                const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
                y_dot_ygrad_block_accum_buf.AtomicAdd(
                    idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM] * p_dropout); // p_dropout
            });
            block_sync_lds();

            // distribute y_dot_ygrad to threads;
            // LDS accum buffer can be safely reused after barrier
            y_dot_ygrad_thread_copy_lds_to_vgpr.Run(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
                                                    y_dot_ygrad_block_accum_buf,
                                                    y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4,
                                                    make_tuple(I0, I0, I0, I0, I0, I0),
                                                    y_dot_ygrad_thread_buf);
            block_sync_lds();

            lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
                                               lse_grid_buf,
                                               lse_thread_desc_mb_m0_m1_m2_m3_m4,
                                               make_tuple(I0, I0, I0, I0, I0, I0),
                                               lse_thread_buf);
            // S = Q * K^T
            gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
                q_grid_desc_k0_m_k1,
                Gemm0::a_block_desc_ak0_m_ak1,
                s_gemm_tile_q_blockwise_copy,
                q_grid_buf,
                gemm0_a_block_buf,
                Gemm0::a_block_slice_copy_step,
                k_grid_desc_k0_n_k1,
                Gemm0::b_block_desc_bk0_n_bk1,
                s_gemm_tile_k_blockwise_copy,
                k_grid_buf,
                gemm0_b_block_buf,
                Gemm0::b_block_slice_copy_step,
                s_blockwise_gemm,
                s_slash_p_thread_buf,
                num_k_block_main_loop);

            // 8d thread_desc in thread scope
            constexpr auto c_thread_lengths =
                s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
            // 8d block_desc in block scope
            constexpr auto c_block_lengths =
                s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();

            constexpr auto M0 = c_block_lengths[I0];
            constexpr auto N0 = c_block_lengths[I1];
            constexpr auto M1 = c_block_lengths[I2];
            constexpr auto N1 = c_block_lengths[I3];
            constexpr auto M2 = c_block_lengths[I4];
            constexpr auto M3 = c_block_lengths[I5];
            constexpr auto M4 = c_block_lengths[I6];
            constexpr auto N2 = c_block_lengths[I7];

            // works like multi-dimension static_for (static_ford), but provides both the linear
            // index as well as the n-d index
            using Acc0TileIterator = SpaceFillingCurve<
                decltype(c_thread_lengths),
                typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
                typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
                false>; // SnakeCurved

            constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                           make_unmerge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));

            // do MNK padding or upper triangular masking
            if constexpr(MaskOutUpperTriangle || PadN)
            {
                static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
                    auto m_local =
                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
                    auto n_local =
                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
                    auto m_global    = m_local + m_block_data_idx_on_grid;
                    auto n_global    = n_local + n_block_data_idx_on_grid;
                    bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
                    s_element_op(s_slash_p_thread_buf(i),
                                 masked_flag ? -ck::NumericLimits<float>::Infinity()
                                             : s_slash_p_thread_buf[i]);
                });
            }
            else
            {
                static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
                    s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                });
            }
            block_sync_lds(); // wait for lds read in gemm0 blockwise gemm

            // add bias
            if constexpr(!is_same<D0DataType, void>::value)
            {
                if(p_d0_grid != nullptr)
                {
                    static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
                    const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                    auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
                        D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                    auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
                        D0Operator::d0_thread_desc_.GetElementSpaceSize());

                    static_for<0, D0M0, 1>{}([&](auto mr) {
                        // load data to lds
                        d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                            d0_grid_buf);
                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                            d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                        d0_block_copy_global_to_lds.RunWrite(
                            D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                        block_sync_lds();
                        // read data from lds
                        d0_thread_copy_lds_to_vgpr.Run(
                            D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
                            make_tuple(I0, I0, I0, I0, I0),
                            d0_block_buf,
                            D0Operator::d0_thread_desc_,
                            make_tuple(I0, I0, I0, I0, I0),
                            d0_thread_buf);
                        // bias add
                        static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
                            constexpr index_t c_offset =
                                c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
                            s_slash_p_thread_buf(Number<c_offset>{}) +=
                                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                        });
                    });
                    d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                        d0_grid_desc_m0_n0_m1_m2_n1_m3,
                        make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
                }
            }
            // P_i: = softmax(scalar * S_i:)
            // scaling is already performed in the preceding statements with s_element_op
            blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
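            // Softmax-recompute sketch (the standard flash-attention backward
            // identity, assuming lse_thread_buf holds LSE_i = log(sum_j exp(s_ij))
            // from the forward pass): the call above can rebuild each row without
            // a fresh max/sum reduction, since
            //   P_ij = exp(s_ij - LSE_i)
            // which is exactly softmax(S_i:) evaluated with stored statistics.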
            // save z to global
            if constexpr(IsDropout)
            {
                if(p_z_grid)
                {
                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
                    auto m_local =
                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
                    auto n_local =
                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
                    auto m_global       = m_local + m_block_data_idx_on_grid;
                    auto n_global       = n_local + n_block_data_idx_on_grid;
                    auto global_tile_id = z_random_matrix_offset +
                                          (m_global / DropoutTile) * DropoutTile * raw_n_padded +
                                          (n_global / DropoutTile) * DropoutTile;
                    auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
                                          (n_global % DropoutTile) * raw_n_padded;

                    blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<
                        decltype(s_slash_p_thread_buf),
                        decltype(z_tensor_buffer),
                        decltype(DropoutTile),
                        true>(
                        s_slash_p_thread_buf, ph, global_elem_id, z_tensor_buffer, raw_n_padded);

                    z_thread_copy_vgpr_to_global.Run(
                        z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                        z_tensor_buffer,
                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                        z_grid_buf);
                }
                else
                {
                    ignore = z_grid_buf;
                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
                    auto m_local =
                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
                    auto n_local =
                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
                    auto m_global       = m_local + m_block_data_idx_on_grid;
                    auto n_global       = n_local + n_block_data_idx_on_grid;
                    auto global_tile_id = z_random_matrix_offset +
                                          (m_global / DropoutTile) * DropoutTile * raw_n_padded +
                                          (n_global / DropoutTile) * DropoutTile;
                    auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
                                          (n_global % DropoutTile) * raw_n_padded;

                    // P_dropped
                    blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
                                                                   decltype(DropoutTile),
                                                                   true>(
                        s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
                }
            }
            block_sync_lds(); // wait for gemm1 LDS read

            // gemm dV
            // dV = P_drop^T * dY
            {
                // TODO: explore using a dynamic buffer for the a1 thread buffer
                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
                // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
                // the A1 source buffer is a static buffer holding the output of the first GEMM
                // and requires a constexpr offset by design. Therefore, we pass the tensor
                // coordinate offset explicitly in Run() below.

                // preload data into LDS
                vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
                                                             ygrad_grid_buf);
                vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                    ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);

                block_sync_lds(); // wait for previous LDS read

                vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
                                                              gemm1_b_block_buf);

                // main body
                if constexpr(num_gemm1_k_block_inner_loop > 1)
                {
                    static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
                        vgrad_gemm_tile_p_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
                                                             Gemm1::a_block_slice_copy_step * i,
                                                             s_slash_p_thread_buf,
                                                             Gemm1::a_thread_desc_k0_m_k1,
                                                             make_tuple(I0, I0, I0),
                                                             gemm1_a_thread_buf);
                        vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
                                                                     ygrad_grid_buf);
                        block_sync_lds();
                        vgrad_blockwise_gemm.Run(
                            gemm1_a_thread_buf, gemm1_b_block_buf, vgrad_thread_buf);
                        block_sync_lds();
                        vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                            ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);
                        vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(
                            Gemm1::b_block_desc_bk0_n_bk1, gemm1_b_block_buf);
                    });
                }
                // tail
                {
                    vgrad_gemm_tile_p_blockwise_copy.Run(
                        Gemm1::a_src_thread_desc_k0_m_k1,
                        Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
                        s_slash_p_thread_buf,
                        Gemm1::a_thread_desc_k0_m_k1,
                        make_tuple(I0, I0, I0),
                        gemm1_a_thread_buf);
                    block_sync_lds();
                    vgrad_blockwise_gemm.Run(
                        gemm1_a_thread_buf, gemm1_b_block_buf, vgrad_thread_buf);
                }
            } // end gemm dV
            // gemm dP
            block_sync_lds();

            // dP = dY * V^T
            // assume size K == size O so HasMainKBlockLoop is the same
            gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
                ygrad_grid_desc_o0_m_o1,
                Gemm0::a_block_desc_ak0_m_ak1, // reuse
                pgrad_gemm_tile_ygrad_blockwise_copy,
                ygrad_grid_buf,
                gemm0_a_block_buf,              // reuse
                Gemm0::a_block_slice_copy_step, // reuse
                v_grid_desc_o0_n_o1,
                Gemm0::b_block_desc_bk0_n_bk1, // reuse
                pgrad_gemm_tile_v_blockwise_copy,
                v_grid_buf,
                gemm0_b_block_buf,              // reuse
                Gemm0::b_block_slice_copy_step, // reuse
                pgrad_blockwise_gemm,
                pgrad_thread_buf,
                num_o_block_main_loop);
            // dS = P * (dP - Y_dot_dY)
            auto& sgrad_thread_buf = pgrad_thread_buf;

            constexpr auto pgrad_thread_tile_iterator =
                pgrad_blockwise_gemm.MakeCThreadTileIterator();
            constexpr auto pgrad_thread_idx_to_m_n_adaptor =
                pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
            static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
                constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
                constexpr auto m =
                    pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
                // dS and P have the same thread buf layout
                bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
                sgrad_thread_buf(i) =
                    s_slash_p_thread_buf[i] *
                    (undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
                                    : y_dot_ygrad_thread_buf[Number<m>{}]);
            });
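            // The undropped branch above is the standard softmax backward
            // identity: with the rowwise d_i = sum_k Y_ik * dY_ik held in
            // y_dot_ygrad_thread_buf,
            //   dS_ij = P_ij * (dP_ij - d_i).
            // The dropped branch (negative P, per the sign-encoded dropout mask)
            // multiplies by d_i instead; this is a reading of the code, tied to
            // how the dropout step stores dropped elements with flipped sign.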
            // output bias grad
            if constexpr(!is_same<D0DataType, void>::value)
            {
                if(p_d0grad_grid != nullptr)
                {
                    auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                    auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
                        D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                    static_for<0, D0M0, 1>{}([&](auto mr) {
                        d0grad_thread_copy_vgpr_to_lds.Run(
                            D0Operator::d0_thread_desc_,
                            make_tuple(mr, I0, I0, I0, I0),
                            sgrad_thread_buf,
                            D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
                            d0grad_block_buf);
                        block_sync_lds();
                        // write data from lds to global
                        d0grad_block_copy_lds_to_global.Run(
                            D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
                            d0grad_block_buf,
                            d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
                            d0grad_grid_buf,
                            I0);
                        d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
                            d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
                            make_multi_index(0, 0, 1, 0, 0, 0));
                    });
                    d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
                        d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
                        make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
                }
            }
            SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
                                                            s_blockwise_gemm.GetWaveIdx()[I1]);

            constexpr index_t num_gemm2_loop = NPerBlock / Gemm2Params::Sum_K;
            static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
                          "");

            // TODO: tune gemm2 pipeline
            // gemm dQ
            // dQ = scalar * dS * K
            qgrad_thread_buf.Clear();
            static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dQ
                // load QGrad Gemm B
                qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);

                // load QGrad Gemm A
                const auto sgrad_slice_idx =
                    Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
                constexpr auto mwave_range = make_tuple(
                    sgrad_slice_idx[I2],
                    sgrad_slice_idx[I2] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I2));
                constexpr auto nwave_range = make_tuple(
                    sgrad_slice_idx[I3],
                    sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));

                block_sync_lds(); // sync before write

                if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
                {
                    qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
                        Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        make_tuple(
                            sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
                        sgrad_thread_buf,
                        Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
                        gemm2_a_block_buf);
                }

                // k slice window is moved with MoveSrcSliceWindow() since it is a dynamic buffer;
                // sgrad slice window is moved by loop index
                qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                    k_grid_desc_n0_k_n1, Gemm2::b_block_slice_copy_step);
                qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
                                                          gemm2_b_block_buf);

                block_sync_lds(); // sync before read

                qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, qgrad_thread_buf);
            }); // end gemm dQ
            // atomic_add dQ
            qgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                                 qgrad_thread_buf,
                                                 qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
                                                 qgrad_grid_buf);
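            // Why an atomic store path for dQ (an inference from the grid layout,
            // not stated explicitly here): each workgroup owns one N (key/value)
            // tile, but every such tile contributes dS * K terms to the same rows
            // of dQ, so partial dQ tiles from different workgroups must be summed
            // in global memory rather than written with a plain Set.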
            // gemm dK
            // dK = scalar * dS^T * Q
            {
                // TODO: explore using a dynamic buffer for the a1 thread buffer
                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
                // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
                // the A1 source buffer is a static buffer holding the output of the first GEMM
                // and requires a constexpr offset by design. Therefore, we pass the tensor
                // coordinate offset explicitly in Run() below.

                // preload data into LDS
                kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
                kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
                    q_grid_desc_m0_k_m1, Gemm1::b_block_slice_copy_step);

                block_sync_lds(); // wait for previous LDS read

                kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
                                                          gemm1_b_block_buf);

                // main body
                if constexpr(num_gemm1_k_block_inner_loop > 1)
                {
                    static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
                        kgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
                                                                 Gemm1::a_block_slice_copy_step * i,
                                                                 sgrad_thread_buf,
                                                                 Gemm1::a_thread_desc_k0_m_k1,
                                                                 make_tuple(I0, I0, I0),
                                                                 gemm1_a_thread_buf);
                        kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
                        block_sync_lds();
                        kgrad_blockwise_gemm.Run(
                            gemm1_a_thread_buf, gemm1_b_block_buf, kgrad_thread_buf);
                        block_sync_lds();
                        kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
                            q_grid_desc_m0_k_m1, Gemm1::b_block_slice_copy_step);
                        kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
                                                                  gemm1_b_block_buf);
                    });
                }
                // tail
                {
                    kgrad_gemm_tile_sgrad_blockwise_copy.Run(
                        Gemm1::a_src_thread_desc_k0_m_k1,
                        Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
                        sgrad_thread_buf,
                        Gemm1::a_thread_desc_k0_m_k1,
                        make_tuple(I0, I0, I0),
                        gemm1_a_thread_buf);
                    block_sync_lds();
                    kgrad_blockwise_gemm.Run(
                        gemm1_a_thread_buf, gemm1_b_block_buf, kgrad_thread_buf);
                }
            } // end gemm dK
            // move slice window
            s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
                q_grid_desc_k0_m_k1, s_gemm_tile_q_block_reset_copy_step); // rewind K and step M
            s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                k_grid_desc_k0_n_k1, s_gemm_tile_k_block_reset_copy_step); // rewind K
            pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                ygrad_grid_desc_o0_m_o1,
                pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O and step M
            pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
                v_grid_desc_o0_n_o1, pgrad_gemm_tile_v_block_reset_copy_step); // rewind O
            qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                k_grid_desc_n0_k_n1, Gemm2::b_block_reset_copy_step); // rewind N
            kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
                q_grid_desc_m0_k_m1, kgrad_gemm_tile_q_block_next_copy_step); // step M
            vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                ygrad_grid_desc_m0_o_m1, vgrad_gemm_tile_ygrad_block_next_copy_step); // step M
            qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M

            lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
                lse_grid_desc_mb_m0_m1_m2_m3_m4, make_multi_index(-1, 0, 0, 0, 0, 0));
            yygrad_threadwise_copy.MoveSrcSliceWindow(
                y_grid_desc_mblock_mperblock_oblock_operblock, make_multi_index(-1, 0, 0, 0));
            z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
        } while(0 < gemm0_m_block_outer_index--); // end j loop
        // shuffle dK & dV and write
        {
            static_assert(NXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = Gemm0NWaves;
            constexpr index_t NWave = Gemm0MWaves;

            // TODO: hacky, fix it!
            // thread desc same as kgrad_blockwise_gemm
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
                vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
            // block desc same as kgrad_blockwise_gemm
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
                vgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
            constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
            constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2)),                                    // M2 = MPerXdl
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2,                                      // N2 * N3 * N4 = NPerXdl
                        N3,
                        N4))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            // index same as kgrad_blockwise_gemm
            const auto c_thread_mtx_on_block =
                vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));
            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));
            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));
            // shuffle: threadwise copy C from VGPR to LDS
            auto vgrad_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
                                                   FloatCShuffle,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                                   SElementwiseOperation,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            I1,
                                                            N2,
                                                            I1,
                                                            N4>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I3],
                                     n_thread_data_on_block_idx[I4]),
                    tensor_operation::element_wise::Scale{rp_dropout}};

            // shuffle: blockwise copy C from LDS to global
            auto vgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,        // typename SrcData,
                OutputDataType,       // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(vgrad_grid_desc_nblock_nperblock_oblock_operblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 vgrad_grid_desc_nblock_nperblock_oblock_operblock,
                 make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
                 c_element_op};
            // shuffle: threadwise copy C from VGPR to LDS
            auto kgrad_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
                                                   FloatCShuffle,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                                   SElementwiseOperation,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            I1,
                                                            N2,
                                                            I1,
                                                            N4>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I3],
                                     n_thread_data_on_block_idx[I4]),
                    scale_rp_dropout};

            // shuffle: blockwise copy C from LDS to global
            auto kgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,        // typename SrcData,
                OutputDataType,       // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(kgrad_grid_desc_nblock_nperblock_oblock_operblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 kgrad_grid_desc_nblock_nperblock_oblock_operblock,
                 make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
                 c_element_op};
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<NXdlPerWave, Gemm1NXdlPerWave, 1, 1, 1, N2, 1, N4>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           1,
                                           N2,
                                           1,
                                           N4>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, NPerBlock, 1, Gemm1NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
            // dK
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its data from VGPR to LDS
                kgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                                  kgrad_thread_buf,
                                                  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                                  c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copies its data from LDS to global
                kgrad_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    kgrad_grid_desc_nblock_nperblock_oblock_operblock,
                    kgrad_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
                    // move on C
                    kgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        kgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
                }
            });
            // dV
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its data from VGPR to LDS
                vgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                                  vgrad_thread_buf,
                                                  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                                  c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copies its data from LDS to global
                vgrad_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    vgrad_grid_desc_nblock_nperblock_oblock_operblock,
                    vgrad_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
                    // move on C
                    vgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        vgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
                }
            });
        }
    }
};

} // namespace ck