Commit 35b2971e authored by danyao12's avatar danyao12
Browse files

fix bugs and optimize bwd qloop 2 kernels

parent 52478ac3
...@@ -73,7 +73,7 @@ __global__ void ...@@ -73,7 +73,7 @@ __global__ void
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -140,7 +140,7 @@ __global__ void ...@@ -140,7 +140,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -175,7 +175,7 @@ __global__ void ...@@ -175,7 +175,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
......
...@@ -1040,7 +1040,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1040,7 +1040,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ave_time = launch_kernel(integral_constant<bool, false>{}, ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
return ave_time; return ave_time;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -36,7 +36,8 @@ __global__ void ...@@ -36,7 +36,8 @@ __global__ void
kernel_grouped_multihead_attention_backward_ydotygrad_v1( kernel_grouped_multihead_attention_backward_ydotygrad_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count) const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args)); cast_pointer_to_generic_address_space(group_kernel_args));
...@@ -89,12 +90,13 @@ template <typename GridwiseGemm, ...@@ -89,12 +90,13 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic> bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1( kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
...@@ -106,7 +108,8 @@ __global__ void ...@@ -106,7 +108,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
...@@ -158,7 +161,7 @@ __global__ void ...@@ -158,7 +161,7 @@ __global__ void
{ {
for(index_t i = 0; i < num_blocks_per_batch; i++) for(index_t i = 0; i < num_blocks_per_batch; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -180,7 +183,6 @@ __global__ void ...@@ -180,7 +183,6 @@ __global__ void
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_, arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -194,7 +196,7 @@ __global__ void ...@@ -194,7 +196,7 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -216,7 +218,6 @@ __global__ void ...@@ -216,7 +218,6 @@ __global__ void
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_, arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -307,7 +308,7 @@ template <index_t NumDimG, ...@@ -307,7 +308,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -320,7 +321,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -320,7 +321,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1; using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1;
struct ProblemDesc struct ProblemDesc
{ {
std::vector<index_t> a_gs_ms_ks_lengths; std::vector<index_t> a_gs_ms_ks_lengths;
...@@ -341,9 +342,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -341,9 +342,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> d_gs_ms_lengths;
std::vector<index_t> d_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths; std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides; std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
...@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const auto M = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock; const auto M = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
return transform_tensor_descriptor(d_grid_desc_mraw, if constexpr(GemmSpec == GemmSpecialization::MPadding ||
make_tuple(make_right_pad_transform(MRaw, MPad)), GemmSpec == GemmSpecialization::MNPadding ||
make_tuple(Sequence<0>{}), GemmSpec == GemmSpecialization::MKPadding ||
make_tuple(Sequence<0>{})); GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(d_grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return d_grid_desc_mraw;
}
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
LSEGridDesc_M, LSEGridDesc_M,
LSEGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -725,14 +734,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -725,14 +734,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
// GridwiseYDotYGrad // GridwiseYDotYGrad
using GridwiseYDotYGrad = using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B DDataType,
DDataType, // datatype DYGridDesc_M_O,
YGridDesc_M_O, DGridDesc_M,
DGridDesc_M, BlockSize,
BlockSize, DMPerBlock,
DMPerBlock, DKPerBlock,
DKPerBlock>; Gemm1NPerBlock>;
using DBlock2CTileMap = using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>; OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
...@@ -776,7 +785,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -776,7 +785,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameter // D parameter
DDataType* p_d_grid_; DDataType* p_d_grid_;
DYGridDesc_M_O d_y_grid_desc_m_o_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
DBlock2CTileMap d_block_2_ctile_map_; DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -956,7 +964,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -956,7 +964,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameters // D parameters
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]); const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m = const auto d_grid_desc_m =
DeviceOp::MakeDGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]); DeviceOp::MakeDGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N( const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
...@@ -1001,7 +1009,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1001,7 +1009,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
raw_m_padded, raw_m_padded,
raw_n_padded, raw_n_padded,
p_d_grid, p_d_grid,
d_y_grid_desc_m_o,
d_grid_desc_m, d_grid_desc_m,
d_block_2_ctile_map, d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_nblock_nperblock, d_y_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -1105,17 +1112,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1105,17 +1112,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
ave_time = launch_kernel(); ave_time = launch_kernel();
} }
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1< const auto kernel =
GridwiseGemm, kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1<
GroupKernelArg, GridwiseGemm,
AElementwiseOperation, GroupKernelArg,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
has_main_k_block_loop_, CElementwiseOperation,
Deterministic>; has_main_k_block_loop_,
is_dropout_,
Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
...@@ -1139,11 +1148,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1139,11 +1148,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
ave_time += launch_kernel(integral_constant<bool, true>{}); if(arg.p_dropout_ > 0.0)
ave_time += launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
else
ave_time += launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
ave_time += launch_kernel(integral_constant<bool, false>{}); if(arg.p_dropout_ > 0.0)
ave_time += launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
else
ave_time += launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
} }
else else
{ {
...@@ -1169,22 +1188,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1169,22 +1188,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
for(index_t i = 0; i < arg.group_count_; i++) for(index_t i = 0; i < arg.group_count_; i++)
{ {
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i]; const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i]; const auto& device_arg = arg.group_device_args_[i];
if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.d_y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
...@@ -1352,7 +1367,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1352,7 +1367,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1" str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -36,7 +36,8 @@ __global__ void ...@@ -36,7 +36,8 @@ __global__ void
kernel_grouped_multihead_attention_backward_ydotygrad_v2( kernel_grouped_multihead_attention_backward_ydotygrad_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count) const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args)); cast_pointer_to_generic_address_space(group_kernel_args));
...@@ -89,12 +90,13 @@ template <typename GridwiseGemm, ...@@ -89,12 +90,13 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic> bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2( kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
...@@ -106,7 +108,8 @@ __global__ void ...@@ -106,7 +108,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
...@@ -158,7 +161,7 @@ __global__ void ...@@ -158,7 +161,7 @@ __global__ void
{ {
for(index_t i = 0; i < num_blocks_per_batch; i++) for(index_t i = 0; i < num_blocks_per_batch; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -180,7 +183,6 @@ __global__ void ...@@ -180,7 +183,6 @@ __global__ void
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_, arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -194,7 +196,7 @@ __global__ void ...@@ -194,7 +196,7 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr, z_matrix_ptr,
...@@ -216,7 +218,6 @@ __global__ void ...@@ -216,7 +218,6 @@ __global__ void
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_, arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -314,7 +315,7 @@ template <index_t NumDimG, ...@@ -314,7 +315,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -327,7 +328,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -327,7 +328,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2; using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2;
struct ProblemDesc struct ProblemDesc
{ {
std::vector<index_t> a_gs_ms_ks_lengths; std::vector<index_t> a_gs_ms_ks_lengths;
...@@ -348,9 +349,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -348,9 +349,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> d_gs_ms_lengths;
std::vector<index_t> d_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths; std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides; std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
...@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const auto M = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock; const auto M = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
return transform_tensor_descriptor(d_grid_desc_mraw, if constexpr(GemmSpec == GemmSpecialization::MPadding ||
make_tuple(make_right_pad_transform(MRaw, MPad)), GemmSpec == GemmSpecialization::MNPadding ||
make_tuple(Sequence<0>{}), GemmSpec == GemmSpecialization::MKPadding ||
make_tuple(Sequence<0>{})); GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(d_grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return d_grid_desc_mraw;
}
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
LSEGridDesc_M, LSEGridDesc_M,
LSEGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -733,14 +742,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -733,14 +742,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
// GridwiseYDotYGrad // GridwiseYDotYGrad
using GridwiseYDotYGrad = using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B DDataType,
DDataType, // datatype DYGridDesc_M_O,
YGridDesc_M_O, DGridDesc_M,
DGridDesc_M, BlockSize,
BlockSize, DMPerBlock,
DMPerBlock, DKPerBlock,
DKPerBlock>; Gemm1NPerBlock>;
using DBlock2CTileMap = using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>; OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
...@@ -784,7 +793,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -784,7 +793,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameter // D parameter
DDataType* p_d_grid_; DDataType* p_d_grid_;
DYGridDesc_M_O d_y_grid_desc_m_o_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
DBlock2CTileMap d_block_2_ctile_map_; DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -870,6 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -870,6 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
index_t z_random_matrix_offset = 0; index_t z_random_matrix_offset = 0;
d_grid_size_ = 0; d_grid_size_ = 0;
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
...@@ -920,6 +929,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -920,6 +929,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart); const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0); const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp = const index_t grid_size_grp =
(Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(k_grid_desc_n_k)) * (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(k_grid_desc_n_k)) *
...@@ -959,7 +972,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -959,7 +972,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameters // D parameters
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]); const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m = const auto d_grid_desc_m =
DeviceOp::MakeDGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]); DeviceOp::MakeDGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N( const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
...@@ -1004,7 +1017,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1004,7 +1017,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
raw_m_padded, raw_m_padded,
raw_n_padded, raw_n_padded,
p_d_grid, p_d_grid,
d_y_grid_desc_m_o,
d_grid_desc_m, d_grid_desc_m,
d_block_2_ctile_map, d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_nblock_nperblock, d_y_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -1107,17 +1119,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1107,17 +1119,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
ave_time = launch_kernel(); ave_time = launch_kernel();
} }
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2< const auto kernel =
GridwiseGemm, kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2<
GroupKernelArg, GridwiseGemm,
AElementwiseOperation, GroupKernelArg,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
has_main_k_block_loop_, CElementwiseOperation,
Deterministic>; has_main_k_block_loop_,
is_dropout_,
Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
...@@ -1141,11 +1155,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1141,11 +1155,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
ave_time += launch_kernel(integral_constant<bool, true>{}); if(arg.p_dropout_ > 0.0)
ave_time += launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
else
ave_time += launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
ave_time += launch_kernel(integral_constant<bool, false>{}); if(arg.p_dropout_ > 0.0)
ave_time += launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
else
ave_time += launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
} }
else else
{ {
...@@ -1171,7 +1195,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1171,7 +1195,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
...@@ -1181,11 +1207,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1181,11 +1207,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i]; const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i]; const auto& device_arg = arg.group_device_args_[i];
if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.d_y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
...@@ -1358,7 +1379,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1358,7 +1379,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2" str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -41,7 +41,6 @@ template <typename InputDataType, ...@@ -41,7 +41,6 @@ template <typename InputDataType,
typename VGridDesc_O0_N_O1, typename VGridDesc_O0_N_O1,
typename YGridDesc_M_O, typename YGridDesc_M_O,
typename LSEGridDesc_M, typename LSEGridDesc_M,
typename DGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -83,7 +82,7 @@ template <typename InputDataType, ...@@ -83,7 +82,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -1155,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1155,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
} }
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
typename C0MatrixMask, typename C0MatrixMask,
typename YGradGridDesc_O0_M_O1> typename YGradGridDesc_O0_M_O1>
...@@ -1180,7 +1180,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1180,7 +1180,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1, const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
const DGridDesc_M& d_grid_desc_m,
const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1, const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
...@@ -1206,7 +1205,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1206,7 +1205,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize()); p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m.GetElementSpaceSize()); p_d_grid, lse_grid_desc_m.GetElementSpaceSize()); // reuse lse grid descriptor
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize()); p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -1532,6 +1531,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1532,6 +1531,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
acc0_thread_origin[I5], acc0_thread_origin[I5],
acc0_thread_origin[I6])}; acc0_thread_origin[I6])};
auto d_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatD,
FloatGemmAcc,
decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
// //
// z vgpr copy to global // z vgpr copy to global
// //
...@@ -1651,11 +1669,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1651,11 +1669,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// load d and lse // load d and lse
// //
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4, d_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
d_grid_buf, d_grid_buf,
lse_thread_desc_mb_m0_m1_m2_m3_m4, lse_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
y_dot_ygrad_thread_buf); y_dot_ygrad_thread_buf);
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4, lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
lse_grid_buf, lse_grid_buf,
...@@ -1743,56 +1761,64 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1743,56 +1761,64 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
constexpr auto position_offset = M3 * M4; constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if(p_z_grid) if constexpr(IsDropout)
{ {
if(p_z_grid)
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; auto m_local =
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto m_global = m_local + m_block_data_idx_on_grid; auto n_local =
auto n_global = n_local + n_block_data_idx_on_grid; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto n_global = n_local + n_block_data_idx_on_grid;
n_global; // unique element global 1d id
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
auto global_elem_id = n_global; // unique element global 1d id
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id =
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
decltype(z_tenor_buffer),
decltype(position_offset), blockwise_dropout
true>( .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded); decltype(z_tenor_buffer),
decltype(position_offset),
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, true>(
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_thread_copy_vgpr_to_global.Run(
z_grid_buf); z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
} make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
else z_tenor_buffer,
{ z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
ignore = z_grid_buf; z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; auto m_local =
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto m_global = m_local + m_block_data_idx_on_grid; auto n_local =
auto n_global = n_local + n_block_data_idx_on_grid; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id n_global; // unique element global 1d id
auto global_elem_id = auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(position_offset),
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
} }
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY) // dS = P * (dP - Y_dot_dY)
...@@ -1965,6 +1991,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1965,6 +1991,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4, lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0)); make_multi_index(-1, 0, 0, 0, 0, 0));
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop } while(0 < gemm0_m_block_outer_index--); // end j loop
// shuffle dK&dV and write // shuffle dK&dV and write
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -41,7 +41,6 @@ template <typename InputDataType, ...@@ -41,7 +41,6 @@ template <typename InputDataType,
typename VGridDesc_N0_O_N1, typename VGridDesc_N0_O_N1,
typename YGridDesc_M_O, typename YGridDesc_M_O,
typename LSEGridDesc_M, typename LSEGridDesc_M,
typename DGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -91,7 +90,7 @@ template <typename InputDataType, ...@@ -91,7 +90,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -1110,6 +1109,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1110,6 +1109,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
} }
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
typename C0MatrixMask, typename C0MatrixMask,
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
...@@ -1135,7 +1135,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1135,7 +1135,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1, const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
const DGridDesc_M& d_grid_desc_m,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1, const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
...@@ -1161,7 +1160,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1161,7 +1160,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize()); p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m.GetElementSpaceSize()); p_d_grid, lse_grid_desc_m.GetElementSpaceSize()); // reuse lse grid descriptor
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize()); p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -1516,6 +1515,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1516,6 +1515,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
acc0_thread_origin[I5], acc0_thread_origin[I5],
acc0_thread_origin[I6])}; acc0_thread_origin[I6])};
auto d_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatD,
FloatGemmAcc,
decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
// //
// z vgpr copy to global // z vgpr copy to global
// //
...@@ -1612,11 +1630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1612,11 +1630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// load d and lse // load d and lse
// //
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4, d_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
d_grid_buf, d_grid_buf,
lse_thread_desc_mb_m0_m1_m2_m3_m4, lse_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
y_dot_ygrad_thread_buf); y_dot_ygrad_thread_buf);
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4, lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
lse_grid_buf, lse_grid_buf,
...@@ -1706,55 +1724,63 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1706,55 +1724,63 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
constexpr auto position_offset = M3 * M4; constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if(p_z_grid) if constexpr(IsDropout)
{ {
if(p_z_grid)
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; auto m_local =
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto m_global = m_local + m_block_data_idx_on_grid; auto n_local =
auto n_global = n_local + n_block_data_idx_on_grid; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto n_global = n_local + n_block_data_idx_on_grid;
n_global; // unique element global 1d id
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
auto global_elem_id = n_global; // unique element global 1d id
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id =
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
decltype(z_tenor_buffer),
decltype(position_offset), blockwise_dropout
true>( .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded); decltype(z_tenor_buffer),
decltype(position_offset),
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, true>(
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_thread_copy_vgpr_to_global.Run(
z_grid_buf); z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
} make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
else z_tenor_buffer,
{ z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
ignore = z_grid_buf; z_grid_buf);
}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin; else
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; {
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; ignore = z_grid_buf;
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
}
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
// gemm dV // gemm dV
...@@ -2005,6 +2031,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -2005,6 +2031,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4, lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0)); make_multi_index(-1, 0, 0, 0, 0, 0));
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -26,7 +26,8 @@ template <typename InputDataType, ...@@ -26,7 +26,8 @@ template <typename InputDataType,
typename DGridDesc_M, typename DGridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock> index_t NPerBlock,
index_t NPadded>
struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -103,7 +104,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -103,7 +104,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
MakeDefaultBlock2CTileMap(const YGridDesc_M_N& y_grid_desc_m_n) MakeDefaultBlock2CTileMap(const YGridDesc_M_N& y_grid_desc_m_n)
{ {
// should rewrite BlockToCTileMap_M00_N0_M01Adapt // should rewrite BlockToCTileMap_M00_N0_M01Adapt
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, 1024, YGridDesc_M_N>(y_grid_desc_m_n); return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPadded, YGridDesc_M_N>(y_grid_desc_m_n);
} }
using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment