"...composable_kernel_rocm.git" did not exist on "7790e8c3f781ec86385c39b9346bdf6fef0a56d3"
Unverified commit e439b369, authored by guangzlu, committed by GitHub

Attn bwd develop qloop (#720)

* fix issues related to decoder tensor transfer
* prototype 1: Q-loop direction with layout change
* remove unneeded templates
* add OutputDataType & Deterministic for pt1q1
* add OutputDataType & Deterministic for pt1q2

---------
Co-authored-by: danyao12 <danyao12@amd.com>
parent 26115ce7
@@ -10,6 +10,7 @@ add_example_executable(example_batched_multihead_attention_forward batched_multi
add_example_executable(example_grouped_multihead_attention_backward grouped_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
add_example_executable(example_batched_multihead_attention_backward_v4 batched_multihead_attention_backward_v4.cpp)
add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
......
@@ -28,7 +28,8 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename DataType,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename AElementwiseOperation,
@@ -46,22 +47,23 @@ template <typename GridwiseGemm,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid,
const InputDataType* __restrict__ p_b1_grid,
const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
@@ -78,6 +80,7 @@ __global__ void
const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const float p_drop,
@@ -111,36 +114,73 @@ __global__ void
ck::philox ph(seed, global_thread_id, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw);
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph,
i);
}
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph,
0);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;
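The if constexpr(Deterministic) / else split above is the heart of the change: when Deterministic is set, the launcher shrinks the grid to one block per batch (see the grid_size computation further down) and that single block calls GridwiseGemm::Run once for every key-block index i in [0, nblock), so contributions to shared outputs are accumulated in a fixed order, which is presumably what makes the result reproducible; otherwise each launched block runs exactly once, passes 0 as the loop index, and lets the block-to-tile map decide which tile it owns. A minimal host-side sketch of the same dispatch shape, with hypothetical names and a plain function standing in for the GPU kernel:

    #include <cstdio>

    // Toy stand-in for GridwiseGemm::Run, parameterized the same way the kernel
    // above is: Deterministic decides whether one "block" loops over all tile
    // indices or each launched block handles exactly one tile.
    template <bool Deterministic>
    void dispatch(int block_id, int nblock)
    {
        if constexpr(Deterministic)
        {
            // one block walks every tile index in a fixed order
            for(int i = 0; i < nblock; ++i)
                std::printf("block %d processes tile %d\n", block_id, i);
        }
        else
        {
            // each launched block owns exactly one tile index
            std::printf("block %d processes tile %d\n", block_id, block_id);
        }
    }

    int main()
    {
        dispatch<true>(0, 4);  // deterministic: block 0 covers tiles 0..3
        dispatch<false>(2, 4); // non-deterministic: block 2 covers only tile 2
    }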
@@ -173,7 +213,8 @@ template <index_t NumDimG,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
@@ -233,6 +274,7 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit attn bwd op once API stabilizes
@@ -599,7 +641,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
@@ -663,22 +707,23 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
// Argument
struct Argument : public BaseArgument
{
Argument(
const DataType* p_a_grid,
const DataType* p_b_grid,
const InputDataType* p_a_grid,
const InputDataType* p_b_grid,
ZDataType* p_z_grid,
const DataType* p_b1_grid,
const DataType* p_c_grid, // for dS
const InputDataType* p_b1_grid,
const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid,
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
@@ -822,16 +867,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
// pointers
const DataType* p_a_grid_;
const DataType* p_b_grid_;
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
ZDataType* p_z_grid_;
const DataType* p_b1_grid_;
const DataType* p_c_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_;
DataType* p_qgrad_grid_;
DataType* p_kgrad_grid_;
DataType* p_vgrad_grid_;
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
@@ -896,14 +941,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_) * arg.batch_count_;
(Deterministic ? 1
: arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_)) *
arg.batch_count_;
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
DataType,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
AElementwiseOperation,
@@ -921,44 +969,46 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_z_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_lse_grid_,
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_drop_,
arg.seed_,
arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_z_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_lse_grid_,
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_drop_,
arg.seed_,
arg.offset_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
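On the launch side the same flag collapses the per-batch grid: in deterministic mode only one block per batch is launched, and the tile count that would normally size the grid is passed to the kernel as the new nblock argument so that block can loop over it (nblock is passed in the non-deterministic case too, where the single-call path simply ignores it). A small check of that arithmetic with illustrative numbers:

    #include <cstdio>

    int main()
    {
        const int tiles_per_batch = 8;  // block_2_ctile_map.CalculateGridSize(...) in the real code
        const int batch_count     = 16;

        const int grid_nondet = tiles_per_batch * batch_count; // one block per (batch, key-block)
        const int grid_det    = 1 * batch_count;               // one block per batch, loops over nblock
        const int nblock      = tiles_per_batch;               // handed to the kernel in both modes

        std::printf("non-deterministic: grid=%d, nblock=%d\n", grid_nondet, nblock);
        std::printf("deterministic:     grid=%d, nblock=%d\n", grid_det, nblock);
    }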
@@ -1068,16 +1118,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
static auto MakeArgument(
const DataType* p_a,
const DataType* p_b,
const InputDataType* p_a,
const InputDataType* p_b,
ZDataType* p_z,
const DataType* p_b1,
const DataType* p_c,
const InputDataType* p_b1,
const InputDataType* p_c,
const LSEDataType* p_lse,
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
@@ -1183,16 +1233,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
static_cast<const DataType*>(p_b),
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1),
static_cast<const DataType*>(p_c),
static_cast<const InputDataType*>(p_b1),
static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid),
static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
......
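Throughout this device-op header the single DataType parameter becomes InputDataType for the read-only tensors (Q, K, V, Y, dY) and OutputDataType for the written gradients (dQ, dK, dV), and the type-erased argument-construction path shown last simply casts the incoming void pointers to those element types. A stripped-down, self-contained sketch of that pattern (toy Argument type and function names, not the actual device-op interface):

    #include <cstdio>

    template <typename InputDataType, typename OutputDataType>
    struct Argument
    {
        const InputDataType* p_q; // forward input, read-only
        OutputDataType* p_qgrad;  // gradient output, writable
    };

    // Type-erased entry point: callers hand in void*, the operator casts to the
    // element types it was instantiated with, as the diff above does with
    // static_cast<const InputDataType*> / static_cast<OutputDataType*>.
    template <typename InputDataType, typename OutputDataType>
    Argument<InputDataType, OutputDataType> make_argument(const void* p_q, void* p_qgrad)
    {
        return {static_cast<const InputDataType*>(p_q),
                static_cast<OutputDataType*>(p_qgrad)};
    }

    int main()
    {
        float q[2]  = {1.0f, 2.0f};
        float dq[2] = {};
        auto arg    = make_argument<float, float>(q, dq);
        std::printf("q[1] = %f\n", static_cast<double>(arg.p_q[1]));
    }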
@@ -20,7 +20,9 @@
namespace ck {
template <typename DataType,
template <typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
@@ -85,6 +87,7 @@ template <typename DataType,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
@@ -439,7 +442,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(q_block_desc_k0_m_k1),
@@ -464,7 +467,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1),
@@ -489,7 +492,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(v_block_desc_k0_n_k1),
@@ -514,7 +517,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(ygrad_block_desc_k0_m_k1),
@@ -806,7 +809,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
OutputDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
ElementwiseOp, // CElementwiseOperation
@@ -1117,7 +1120,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
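YDotYGrad_M_O_ sizes its per-thread vector loads at 16 bytes, so SrcScalarPerVector now depends on sizeof(InputDataType) rather than the old unified DataType, and the thread-cluster shape falls out of it. A quick compile-time check of that arithmetic, assuming illustrative block and slice sizes (the real values come from the kernel's tuning parameters):

    #include <cstdint>
    #include <cstdio>

    // SrcScalarPerVector = 16 bytes / sizeof(element); the thread cluster shape
    // then follows from the O slice length and the block size, as in YDotYGrad_M_O_.
    template <typename InputDataType, int BlockSize, int BlockSliceLength_O>
    void print_layout(const char* name)
    {
        constexpr int SrcScalarPerVector    = 16 / sizeof(InputDataType);
        constexpr int ThreadClusterLength_O = BlockSliceLength_O / SrcScalarPerVector;
        constexpr int ThreadClusterLength_M = BlockSize / ThreadClusterLength_O;
        std::printf("%s: %d elems/vector, cluster O=%d, cluster M=%d\n",
                    name, SrcScalarPerVector, ThreadClusterLength_O, ThreadClusterLength_M);
    }

    int main()
    {
        // illustrative sizes only: 256 threads, O-slice of 64 per block
        print_layout<std::uint16_t, 256, 64>("2-byte type (e.g. fp16)"); // 8 per vector, O=8, M=32
        print_layout<float, 256, 64>("4-byte type (fp32)");              // 4 per vector, O=16, M=16
    }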
@@ -1234,16 +1237,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename Block2CTileMap,
typename C0MatrixMask,
typename YGradGridDesc_O0_M_O1>
__device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
__device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid,
ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
@@ -1265,7 +1268,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw)
const index_t NRaw,
const index_t block_idx_n)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
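p_dropout here is the keep probability and rp_dropout its reciprocal, the usual inverted-dropout rescale that keeps the expected magnitude of the surviving values unchanged. A quick numerical check with an illustrative dropout rate:

    #include <cstdio>

    int main()
    {
        const float p_drop     = 0.1f;             // illustrative dropout rate
        const float p_dropout  = 1.0f - p_drop;    // keep probability, 0.9
        const float rp_dropout = 1.0f / p_dropout; // rescale factor, ~1.111
        // a kept value x contributes x * rp_dropout, so the expectation is preserved
        std::printf("keep=%.3f rescale=%.4f expectation of 1.0 -> %.4f\n",
                    p_dropout, rp_dropout, p_dropout * rp_dropout * 1.0f);
    }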
@@ -1297,9 +1301,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];
// HACK: this forces n_block_data_idx_on_grid into SGPR
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
__builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
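Inside the gridwise kernel the key-block index now has two possible sources: the block-to-tile map in normal mode, or the loop index block_idx_n handed down by the wrapper in deterministic mode; everything downstream, including the dropout-mask copy origin and the dK/dV shuffle destinations further below, switches from block_work_idx[I0] to block_work_idx_n. A tiny sketch of that selection and the resulting row offset, with a hypothetical NPerBlock:

    #include <cstdio>

    // Mirrors: block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];
    //          n_block_data_idx_on_grid = block_work_idx_n * NPerBlock;
    template <bool Deterministic>
    int n_block_data_idx(int mapped_idx, int loop_idx, int n_per_block)
    {
        const int block_work_idx_n = Deterministic ? loop_idx : mapped_idx;
        return block_work_idx_n * n_per_block;
    }

    int main()
    {
        constexpr int NPerBlock = 128; // illustrative tile size
        std::printf("non-deterministic: %d\n", n_block_data_idx<false>(3, 0, NPerBlock)); // 384
        std::printf("deterministic:     %d\n", n_block_data_idx<true>(3, 2, NPerBlock));  // 256
    }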
@@ -1510,7 +1516,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
3,
m2,
1,
false>{
true /* ResetCoordAfterRun */>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(0, // mblock
acc0_thread_origin[I0], // mrepeat
@@ -1554,7 +1560,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
@@ -1574,15 +1580,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, // MBlockId
block_work_idx[I0], // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
make_multi_index(0, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
@@ -1698,7 +1704,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
@@ -1749,11 +1755,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
3,
m2,
1,
false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
true /* ResetCoordAfterRun */>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
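Two of the threadwise transfers above (the LSE read and the y-dot-ygrad copy) flip their reset-coordinate template argument from false to true; since those objects appear to be reused on every pass of the Q-direction outer loop, resetting the thread coordinate after each Run leaves the explicit MoveSrcSliceWindow calls as the only thing that advances the position between iterations. A rough model of that behavior under those assumptions (a toy copier, not the CK ThreadwiseTensorSliceTransfer API):

    #include <cstdio>

    // Toy model of a ResetCoordAfterRun-style knob: Run() walks across one slice;
    // with Reset=true the position snaps back to the slice start afterwards, so a
    // fixed per-iteration window move lands where intended. With Reset=false the
    // position is left at the slice end and the same move would double-advance.
    template <bool ResetCoordAfterRun>
    struct Copier
    {
        int coord = 0;

        int Run(int slice_len) // returns the first element index this run touched
        {
            const int start = coord;
            coord += slice_len; // walk across the slice
            if constexpr(ResetCoordAfterRun)
                coord = start;  // snap back
            return start;
        }
        void MoveSrcSliceWindow(int step) { coord += step; }
    };

    int main()
    {
        Copier<true> with_reset;
        Copier<false> without_reset;
        for(int iter = 0; iter < 3; ++iter)
        {
            std::printf("iter %d: reset starts at %d, no-reset starts at %d\n",
                        iter, with_reset.Run(4), without_reset.Run(4));
            with_reset.MoveSrcSliceWindow(4);    // starts: 0, 4, 8  (as intended)
            without_reset.MoveSrcSliceWindow(4); // starts: 0, 8, 16 (double-advance)
        }
    }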
@@ -1761,6 +1767,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
if constexpr(Deterministic)
{
block_sync_lds();
}
// Initialize dK&dV
kgrad_thread_buf.Clear();
vgrad_thread_buf.Clear();
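With the loop now running along the Q (M) dimension, the outer trip count is M / MPerBlock and each workgroup keeps its dK/dV accumulators live across all of those iterations, which is why they are cleared exactly once here before entering the loop (the extra block_sync_lds in deterministic mode looks like a guard between consecutive Run calls that reuse the same LDS, though that is an inference from the diff). A compact sketch of that accumulation structure, with toy scalars standing in for the VGPR thread buffers:

    #include <cstdio>

    int main()
    {
        const int M = 512, MPerBlock = 128;               // illustrative sizes
        const int num_m_block_outer_loop = M / MPerBlock; // 4 passes over query blocks

        float kgrad_acc = 0.f, vgrad_acc = 0.f; // cleared once, like kgrad/vgrad_thread_buf.Clear()
        for(int m_block = 0; m_block < num_m_block_outer_loop; ++m_block)
        {
            // each pass would run the S/dP/dV/dK/dQ GEMMs for one query block;
            // here we just add a stand-in contribution per pass
            kgrad_acc += 1.0f;
            vgrad_acc += 2.0f;
        }
        std::printf("outer loops=%d, kgrad=%.1f, vgrad=%.1f\n",
                    num_m_block_outer_loop, kgrad_acc, vgrad_acc);
    }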
@@ -1788,6 +1799,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
// move slice window
gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
GemmBlockwiseCopy::gemm_tile_q_block_slice_copy_step); // step M
gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1,
GemmBlockwiseCopy::gemm_tile_ygrad_block_slice_copy_step); // step M
qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
Gemm1::c_block_slice_copy_step); // step M
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(1, 0, 0, 0));
y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(1, 0, 0, 0));
continue;
}
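The masking early-out also had to learn the new loop direction: even when a whole tile is skippable, the Q/dY/LSE/Y source windows and the dQ destination window are advanced before continue, since the next pass of the outer loop would otherwise operate on stale window positions. A minimal sketch of that pattern, with a plain offset standing in for the slice-window coordinates:

    #include <cstdio>

    int main()
    {
        const int num_m_blocks = 4;
        const int m_per_block  = 128;
        int q_window_offset    = 0; // stands in for the blockwise-copy slice window

        for(int m_block = 0; m_block < num_m_blocks; ++m_block)
        {
            const bool tile_skippable = (m_block == 1); // e.g. a fully masked-out tile

            if(tile_skippable)
            {
                // still advance the window so later passes read the right rows
                q_window_offset += m_per_block;
                continue;
            }

            std::printf("processing rows [%d, %d)\n",
                        q_window_offset, q_window_offset + m_per_block);
            q_window_offset += m_per_block;
        }
    }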
@@ -2400,7 +2425,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
DataType, // typename DstData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(vgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
@@ -2411,7 +2436,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// shuffle: threadwise copy C from VGPR to LDS
@@ -2458,7 +2483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
DataType, // typename DstData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(kgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
@@ -2469,7 +2494,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
......