Commit dcfe312b authored by ltqin's avatar ltqin
Browse files

Change the names of some functions

parent a71a3f65
......@@ -31,7 +31,7 @@ namespace device {
template <typename GridwiseGemm,
typename InputDataType,
typename DDataType,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch>
......@@ -43,7 +43,7 @@ __global__ void
const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid,
DDataType* __restrict__ p_d_grid,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_M d_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
......@@ -112,7 +112,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v1(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
......@@ -337,7 +337,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -349,7 +349,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -692,7 +692,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
......@@ -973,7 +973,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d_y_grid_desc_mblock_mperblock_oblock_operblock_;
// element-wise op
......@@ -1037,7 +1037,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InputDataType,
DDataType,
typename GridwiseYDotYGrad::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::DGridDesc_M,
typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch>;
......@@ -1067,31 +1067,32 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.batch_count_;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
DDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M,
DeviceOp::YGradGridDesc_O0_M_O1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
Deterministic>;
const auto kernel =
kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v1<
GridwiseGemm,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
DDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M,
DeviceOp::YGradGridDesc_O0_M_O1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
......@@ -1396,7 +1397,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -30,7 +30,7 @@ namespace device {
template <typename GridwiseGemm,
typename InputDataType,
typename DDataType,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch>
......@@ -42,7 +42,7 @@ __global__ void
const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid,
DDataType* __restrict__ p_d_grid,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_M d_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
......@@ -111,7 +111,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v2(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
......@@ -343,7 +343,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -355,7 +355,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -698,7 +698,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
......@@ -986,7 +986,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d_y_grid_desc_mblock_mperblock_oblock_operblock_;
// element-wise op
......@@ -1050,7 +1050,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
InputDataType,
DDataType,
typename GridwiseYDotYGrad::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::DGridDesc_M,
typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch>;
......@@ -1084,31 +1084,32 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
DDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M,
DeviceOp::YGradGridDesc_M0_O_M1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
Deterministic>;
const auto kernel =
kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v2<
GridwiseGemm,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
DDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M,
DeviceOp::YGradGridDesc_M0_O_M1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
......@@ -1427,7 +1428,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -74,7 +74,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ + d_batch_offset,
static_cast<void*>(p_shared),
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m_,
arg_ptr[group_id].d_block_2_ctile_map_);
#else
......@@ -311,7 +311,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -323,7 +323,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
......@@ -638,7 +638,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
......@@ -758,8 +758,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
DDataType* p_d_grid_;
DGridDesc_M d_grid_desc_m_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
d_y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d_y_grid_desc_mblock_mperblock_nblock_nperblock_;
index_t d_num_blocks_per_batch_;
index_t d_block_start_, d_block_end_;
};
......@@ -950,7 +950,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
const auto d_block_2_ctile_map =
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);
const auto d_y_grid_desc_mblock_mperblock_oblock_operblock =
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o);
......@@ -992,7 +992,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
p_d_grid,
d_grid_desc_m,
d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_oblock_operblock,
d_y_grid_desc_mblock_mperblock_nblock_nperblock,
d_num_blocks_per_batch,
d_block_start,
d_block_end});
......@@ -1164,9 +1164,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
for(index_t i = 0; i < arg.group_count_; i++)
{
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i];
if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
......@@ -1334,7 +1340,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -74,7 +74,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ + d_batch_offset,
static_cast<void*>(p_shared),
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m_,
arg_ptr[group_id].d_block_2_ctile_map_);
#else
......@@ -318,7 +318,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -330,7 +330,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
......@@ -638,7 +638,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
......@@ -765,8 +765,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
DDataType* p_d_grid_;
DGridDesc_M d_grid_desc_m_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
d_y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d_y_grid_desc_mblock_mperblock_nblock_nperblock_;
index_t d_num_blocks_per_batch_;
index_t d_block_start_, d_block_end_;
};
......@@ -952,7 +952,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
const auto d_block_2_ctile_map =
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);
const auto d_y_grid_desc_mblock_mperblock_oblock_operblock =
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o);
......@@ -994,7 +994,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
p_d_grid,
d_grid_desc_m,
d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_oblock_operblock,
d_y_grid_desc_mblock_mperblock_nblock_nperblock,
d_num_blocks_per_batch,
d_block_start,
d_block_end});
......@@ -1168,6 +1168,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i];
if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
......@@ -1340,7 +1345,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -83,7 +83,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
......@@ -91,7 +91,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
......@@ -47,15 +47,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
return false;
}
const auto M = c_grid_desc_m_n.GetLength(I0);
if(M < MPerBlock)
{
return false;
}
if(M % MPerBlock != 0)
// const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
if(N < NPerBlock)
{
return false;
}
// std::cout << "m: " << M <<" n: " << N << std::endl;
// if(M < MPerBlock)
// {
// return false;
// }
// if(M % MPerBlock != 0)
// {
// return false;
// }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
......@@ -69,14 +75,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
const auto y_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
return y_grid_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto
......@@ -102,7 +108,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
c_grid_desc_m_n);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
......@@ -139,15 +145,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const InputDataType* __restrict__ p_ygrad_grid,
FloatD* __restrict__ p_d_grid,
void* __restrict__ p_shared,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDesc_M& d_grid_desc_m,
const DefaultBlock2CTileMap& block_2_ctile_map)
{
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
p_ygrad_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m.GetElementSpaceSize());
......@@ -158,8 +164,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I0),
y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2))))
make_tuple(y_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
y_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
......@@ -193,7 +199,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatD,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
......@@ -201,7 +207,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_nblock_nperblock,
y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
......@@ -217,12 +223,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
index_t oblock_idx = 0;
do
{
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_nblock_nperblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_nblock_nperblock,
ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
......@@ -237,11 +243,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
});
});
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 1, 0));
oblock_idx++;
} while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
} while(oblock_idx < y_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2));
auto d_grid_desc_mblock_mperblock = MakeORSGridDescriptor_MBlock_MPerBlock(d_grid_desc_m);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment