"vscode:/vscode.git/clone" did not exist on "45adb736e7294dd28c2a353ef598cf1802bd6b75"
Commit dcfe312b authored by ltqin's avatar ltqin
Browse files

change some function names

parent a71a3f65
@@ -31,7 +31,7 @@ namespace device {
 template <typename GridwiseGemm,
           typename InputDataType,
           typename DDataType,
-          typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
+          typename YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename DGridDescriptor_M,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch>
@@ -43,7 +43,7 @@ __global__ void
         const InputDataType* __restrict__ p_y_grid,
         const InputDataType* __restrict__ p_ygrad_grid,
         DDataType* __restrict__ p_d_grid,
-        const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
+        const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock,
         const DGridDescriptor_M d_grid_desc_m,
         const Block2CTileMap block_2_ctile_map,
@@ -112,7 +112,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-        kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
+        kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v1(
             const InputDataType* __restrict__ p_a_grid,
             const InputDataType* __restrict__ p_b_grid,
             ZDataType* __restrict__ p_z_grid,
@@ -337,7 +337,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -349,7 +349,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1;

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -692,7 +692,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,
@@ -973,7 +973,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
-        typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
+        typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             d_y_grid_desc_mblock_mperblock_oblock_operblock_;

         // element-wise op
@@ -1037,7 +1037,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 InputDataType,
                 DDataType,
                 typename GridwiseYDotYGrad::
-                    YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
+                    YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 DeviceOp::DGridDesc_M,
                 typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch>;
@@ -1067,7 +1067,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 arg.batch_count_;

            auto launch_kernel = [&](auto has_main_k_block_loop_) {
-                const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
+                const auto kernel =
+                    kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v1<
                     GridwiseGemm,
                     InputDataType,
                     OutputDataType,
@@ -1396,7 +1397,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...
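The launch_kernel lambda in the hunk above takes a compile-time boolean tag (has_main_k_block_loop_) and bakes it into the kernel template it selects. Below is a minimal, self-contained host-side sketch of that dispatch pattern; run_kernel_stub stands in for the real HIP kernel launch and is not code from this commit.

#include <cstdio>
#include <type_traits>

template <bool HasMainKBlockLoop>
void run_kernel_stub() // stand-in for launching a kernel<..., HasMainKBlockLoop>
{
    std::printf("HasMainKBlockLoop = %d\n", HasMainKBlockLoop);
}

int main()
{
    const bool has_main_k_block_loop = true; // normally derived from K and KPerBlock

    // Generic lambda: the bool arrives as std::integral_constant, so it is a
    // compile-time constant inside the lambda and can select a kernel template.
    auto launch_kernel = [&](auto has_main_k_block_loop_) {
        constexpr bool kHasLoop = decltype(has_main_k_block_loop_)::value;
        run_kernel_stub<kHasLoop>();
    };

    if(has_main_k_block_loop)
    {
        launch_kernel(std::integral_constant<bool, true>{});
    }
    else
    {
        launch_kernel(std::integral_constant<bool, false>{});
    }
}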
@@ -30,7 +30,7 @@ namespace device {
 template <typename GridwiseGemm,
           typename InputDataType,
           typename DDataType,
-          typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
+          typename YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename DGridDescriptor_M,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch>
@@ -42,7 +42,7 @@ __global__ void
         const InputDataType* __restrict__ p_y_grid,
         const InputDataType* __restrict__ p_ygrad_grid,
         DDataType* __restrict__ p_d_grid,
-        const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
+        const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock,
         const DGridDescriptor_M d_grid_desc_m,
         const Block2CTileMap block_2_ctile_map,
@@ -111,7 +111,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-        kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
+        kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v2(
             const InputDataType* __restrict__ p_a_grid,
             const InputDataType* __restrict__ p_b_grid,
             ZDataType* __restrict__ p_z_grid,
@@ -343,7 +343,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -355,7 +355,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2;

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -698,7 +698,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,
@@ -986,7 +986,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
-        typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
+        typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             d_y_grid_desc_mblock_mperblock_oblock_operblock_;

         // element-wise op
@@ -1050,7 +1050,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 InputDataType,
                 DDataType,
                 typename GridwiseYDotYGrad::
-                    YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
+                    YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 DeviceOp::DGridDesc_M,
                 typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch>;
@@ -1084,7 +1084,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop_) {
-                const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
+                const auto kernel =
+                    kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v2<
                     GridwiseGemm,
                     InputDataType,
                     OutputDataType,
@@ -1427,7 +1428,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...
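The kernels above are declared under CK_USE_LAUNCH_BOUNDS with __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1). A minimal HIP sketch of that qualifier follows; the kernel, its name, and the block-size value of 256 are illustrative assumptions, not part of this commit.

#include <hip/hip_runtime.h>

// __launch_bounds__(max_threads_per_block, min_blocks_per_cu) lets the compiler budget
// registers for at most 256 threads per block (a stand-in for CK_MAX_THREAD_PER_BLOCK)
// and at least one resident block per compute unit.
__global__ void __launch_bounds__(256, 1) scale_kernel(float* p, float alpha, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        p[i] *= alpha;
    }
}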
@@ -74,7 +74,7 @@ __global__ void
             arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
             arg_ptr[group_id].p_d_grid_ + d_batch_offset,
             static_cast<void*>(p_shared),
-            arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_oblock_operblock_,
+            arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
             arg_ptr[group_id].d_grid_desc_m_,
             arg_ptr[group_id].d_block_2_ctile_map_);
 #else
@@ -311,7 +311,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -323,7 +323,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -638,7 +638,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,
@@ -758,8 +758,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         DDataType* p_d_grid_;
         DGridDesc_M d_grid_desc_m_;
         typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
-        typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
-            d_y_grid_desc_mblock_mperblock_oblock_operblock_;
+        typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            d_y_grid_desc_mblock_mperblock_nblock_nperblock_;
         index_t d_num_blocks_per_batch_;
         index_t d_block_start_, d_block_end_;
     };
@@ -950,7 +950,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
             const auto d_block_2_ctile_map =
                 GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);

-            const auto d_y_grid_desc_mblock_mperblock_oblock_operblock =
+            const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
                 GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     y_grid_desc_m_o);

@@ -992,7 +992,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 p_d_grid,
                 d_grid_desc_m,
                 d_block_2_ctile_map,
-                d_y_grid_desc_mblock_mperblock_oblock_operblock,
+                d_y_grid_desc_mblock_mperblock_nblock_nperblock,
                 d_num_blocks_per_batch,
                 d_block_start,
                 d_block_end});
@@ -1164,9 +1164,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         for(index_t i = 0; i < arg.group_count_; i++)
         {
             // TODO: Check if tensor specialization & strides mismatch
             const auto& kernel_arg = arg.group_kernel_args_[i];
             const auto& device_arg = arg.group_device_args_[i];

+            if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.y_grid_desc_m_o_,
+                                                 kernel_arg.d_block_2_ctile_map_))
+            {
+                return false;
+            }
+
             // Check if C permute dimension matches GEMM + GEMM shape
             const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
             const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
@@ -1334,7 +1340,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...
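The @@ -1164 hunk above adds a per-group call to GridwiseYDotYGrad::CheckValidity inside IsSupportedArgument, rejecting the whole grouped problem as soon as one group fails. A host-only sketch of that early-exit pattern follows; the GroupArg stub and the NPerBlock-style check are assumptions used only to keep the example self-contained, not the CK API.

#include <cstdio>
#include <vector>

struct GroupArg
{
    int m; // rows of the Y / dY tile for this group (hypothetical stand-in)
    int n; // columns (head dimension O) for this group
};

// Stand-in for GridwiseYDotYGrad::CheckValidity(); the real check also inspects the
// block-to-C-tile map.
static bool CheckValidityStub(const GroupArg& g, int n_per_block)
{
    return g.n >= n_per_block;
}

static bool IsSupportedArgumentStub(const std::vector<GroupArg>& groups, int n_per_block)
{
    for(const auto& g : groups)
    {
        if(!CheckValidityStub(g, n_per_block)) // mirrors the added early return
        {
            return false;
        }
    }
    return true;
}

int main()
{
    const std::vector<GroupArg> groups{{128, 64}, {256, 16}};
    std::printf("supported: %d\n", IsSupportedArgumentStub(groups, /*NPerBlock=*/32));
}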
@@ -74,7 +74,7 @@ __global__ void
             arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
             arg_ptr[group_id].p_d_grid_ + d_batch_offset,
             static_cast<void*>(p_shared),
-            arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_oblock_operblock_,
+            arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
             arg_ptr[group_id].d_grid_desc_m_,
             arg_ptr[group_id].d_block_2_ctile_map_);
 #else
@@ -318,7 +318,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -330,7 +330,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -638,7 +638,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,
@@ -765,8 +765,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         DDataType* p_d_grid_;
         DGridDesc_M d_grid_desc_m_;
         typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
-        typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
-            d_y_grid_desc_mblock_mperblock_oblock_operblock_;
+        typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            d_y_grid_desc_mblock_mperblock_nblock_nperblock_;
         index_t d_num_blocks_per_batch_;
         index_t d_block_start_, d_block_end_;
     };
@@ -952,7 +952,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
             const auto d_block_2_ctile_map =
                 GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);

-            const auto d_y_grid_desc_mblock_mperblock_oblock_operblock =
+            const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
                 GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     y_grid_desc_m_o);

@@ -994,7 +994,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 p_d_grid,
                 d_grid_desc_m,
                 d_block_2_ctile_map,
-                d_y_grid_desc_mblock_mperblock_oblock_operblock,
+                d_y_grid_desc_mblock_mperblock_nblock_nperblock,
                 d_num_blocks_per_batch,
                 d_block_start,
                 d_block_end});
@@ -1168,6 +1168,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // TODO: Check if tensor specialization & strides mismatch
             const auto& kernel_arg = arg.group_kernel_args_[i];
             const auto& device_arg = arg.group_device_args_[i];
+            if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.y_grid_desc_m_o_,
+                                                 kernel_arg.d_block_2_ctile_map_))
+            {
+                return false;
+            }
             // Check if C permute dimension matches GEMM + GEMM shape
             const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
             const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
@@ -1340,7 +1345,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...
@@ -83,7 +83,7 @@ template <typename InputDataType,
           bool MaskOutUpperTriangle,
           bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
...
@@ -91,7 +91,7 @@ template <typename InputDataType,
           bool MaskOutUpperTriangle,
           bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
...
@@ -47,15 +47,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         {
             return false;
         }
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        if(M < MPerBlock)
+        // const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+        if(N < NPerBlock)
         {
             return false;
         }
-        if(M % MPerBlock != 0)
-        {
-            return false;
-        }
+        // std::cout << "m: " << M <<" n: " << N << std::endl;
+        // if(M < MPerBlock)
+        // {
+        //     return false;
+        // }
+        // if(M % MPerBlock != 0)
+        // {
+        //     return false;
+        // }
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }
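The hunk above relaxes the Y-dot-dY validity check: before this commit M had to be a positive multiple of MPerBlock, afterwards only problems with N < NPerBlock are rejected. A plain host C++ restatement of the two predicates follows, with hypothetical MPerBlock/NPerBlock values, to make the behavioural change explicit.

#include <cstdio>

constexpr int MPerBlock = 128; // illustrative tile sizes, not taken from this commit
constexpr int NPerBlock = 32;

// Predicate before the change: M must be a positive multiple of MPerBlock.
bool check_validity_old(int M, int /*N*/)
{
    if(M < MPerBlock) { return false; }
    if(M % MPerBlock != 0) { return false; }
    return true;
}

// Predicate after the change: only N is constrained; M may be ragged.
bool check_validity_new(int /*M*/, int N) { return N >= NPerBlock; }

int main()
{
    // A ragged M (e.g. 100) is now accepted as long as N covers one NPerBlock tile.
    std::printf("old: %d, new: %d\n", check_validity_old(100, 64), check_validity_new(100, 64));
}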
@@ -69,14 +75,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         const auto MBlock = M / MPerBlock;
         const auto NBlock = N / NPerBlock;

-        const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
+        const auto y_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
             c_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                        make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

-        return y_grid_desc_mblock_mperblock_oblock_operblock;
+        return y_grid_desc_mblock_mperblock_nblock_nperblock;
     }

     __host__ __device__ static constexpr auto
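MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock above unmerges an (M, N) descriptor into (MBlock, MPerBlock, NBlock, NPerBlock), which is what the renamed descriptor type now advertises. The sketch below reproduces just the index arithmetic of that unmerge in plain C++, without CK's descriptor machinery; the tile sizes are illustrative only.

#include <cstdio>

// Split a flat (m, n) coordinate into the blocked coordinate
// (mblock, mperblock, nblock, nperblock) exposed by the renamed descriptor.
struct BlockedIdx
{
    int mblock, mperblock, nblock, nperblock;
};

BlockedIdx to_blocked(int m, int n, int m_per_block, int n_per_block)
{
    return {m / m_per_block, m % m_per_block, n / n_per_block, n % n_per_block};
}

int main()
{
    // Hypothetical tile sizes; CK derives these from the device-op template parameters.
    const int MPerBlock = 128;
    const int NPerBlock = 32;

    const BlockedIdx idx = to_blocked(/*m=*/300, /*n=*/70, MPerBlock, NPerBlock);
    std::printf("(mblock, mperblock, nblock, nperblock) = (%d, %d, %d, %d)\n",
                idx.mblock, idx.mperblock, idx.nblock, idx.nperblock);
}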
@@ -102,7 +108,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             c_grid_desc_m_n);
     }

-    using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
+    using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;

     using DefaultBlock2CTileMap =
@@ -139,15 +145,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         const InputDataType* __restrict__ p_ygrad_grid,
         FloatD* __restrict__ p_d_grid,
         void* __restrict__ p_shared,
-        const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
-            y_grid_desc_mblock_mperblock_oblock_operblock,
+        const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+            y_grid_desc_mblock_mperblock_nblock_nperblock,
         const DGridDesc_M& d_grid_desc_m,
         const DefaultBlock2CTileMap& block_2_ctile_map)
     {
         const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
+            p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_ygrad_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
+            p_ygrad_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_d_grid, d_grid_desc_m.GetElementSpaceSize());

@@ -158,8 +164,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         if(!block_2_ctile_map.ValidCTileIndex(
                block_work_idx,
-               make_tuple(y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I0),
-                          y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2))))
+               make_tuple(y_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          y_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
         {
             return;
         }

@@ -193,7 +199,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
             InputDataType,
             FloatD,
-            YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
+            YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
             decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
             Sequence<0, 1, 2, 3>,
@@ -201,7 +207,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
             1,                                 // SrcScalarStrideInVector
             true /* ResetCoordAfterRun */,
-            false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
+            false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_nblock_nperblock,
                                              y_thread_data_on_grid_idx);

         auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
@@ -217,12 +223,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         index_t oblock_idx = 0;
         do
         {
-            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
+            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_nblock_nperblock,
                                        y_grid_buf,
                                        y_thread_desc_m0_m1_o0_o1,
                                        make_tuple(I0, I0, I0, I0),
                                        y_thread_buf);
-            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
+            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_nblock_nperblock,
                                        ygrad_grid_buf,
                                        y_thread_desc_m0_m1_o0_o1,
                                        make_tuple(I0, I0, I0, I0),
@@ -237,11 +243,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
                 });
             });

-            yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
+            yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_nblock_nperblock,
                                                       make_multi_index(0, 0, 1, 0));

             oblock_idx++;
-        } while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
+        } while(oblock_idx < y_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2));

         auto d_grid_desc_mblock_mperblock = MakeORSGridDescriptor_MBlock_MPerBlock(d_grid_desc_m);
...
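The Run() loop above streams Y and dY tiles block-by-block along the O dimension and accumulates their element-wise products per row. Assuming this kernel computes the usual attention-backward reduction D[m] = sum over o of Y[m,o] * dY[m,o] (an inference from the code, not stated in the commit), a scalar host reference would look like the sketch below.

#include <cstdio>
#include <vector>

// Host reference for the per-row reduction the Y-dot-dY grid kernel appears to perform:
// d[m] = sum_o y[m][o] * ygrad[m][o], with y and ygrad stored row-major (M x O).
std::vector<float> y_dot_ygrad_reference(const std::vector<float>& y,
                                         const std::vector<float>& ygrad,
                                         int M,
                                         int O)
{
    std::vector<float> d(M, 0.0f);
    for(int m = 0; m < M; ++m)
    {
        for(int o = 0; o < O; ++o)
        {
            d[m] += y[m * O + o] * ygrad[m * O + o];
        }
    }
    return d;
}

int main()
{
    const int M = 2, O = 3;
    const std::vector<float> y{1, 2, 3, 4, 5, 6};
    const std::vector<float> ygrad{1, 1, 1, 0.5f, 0.5f, 0.5f};
    const auto d = y_dot_ygrad_reference(y, ygrad, M, O);
    std::printf("d = [%f, %f]\n", d[0], d[1]); // expect [6, 7.5]
}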