Commit 14d29856 authored by rocking
Browse files

Pad E and H to different sizes in the layernorm kernel according to the different block tile

parent d78877a7
...@@ -202,9 +202,9 @@ template <typename ALayout, ...@@ -202,9 +202,9 @@ template <typename ALayout,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t GemmMPerBlock,
index_t NPerBlock, index_t GemmNPerBlock,
index_t KPerBlock, index_t GemmKPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t MPerXDL, index_t MPerXDL,
...@@ -274,8 +274,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -274,8 +274,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder = MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; GemmMPerBlock, GemmNPerBlock, GemmKPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
...@@ -313,21 +313,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -313,21 +313,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
template <typename LayOut> template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeEHGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride) static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride)
{
const auto grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(Stride, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, LayOut>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(I1, Stride)); // Only support row major for E and H
} const auto grid_desc_m_n =
}(); make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(Stride, I1));
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
return matrix_padder.PadCDescriptor_M_N(grid_desc_mraw_nraw);
} }
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws, static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
...@@ -337,8 +329,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -337,8 +329,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
static_assert(is_same<tensor_layout::gemm::RowMajor, DLayout>::value);
return DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]); return DeviceOp::
MakeEHGridDescriptor_M_N<Sequence<true, true>, GemmMPerBlock, GemmNPerBlock>(
MRaws[i], NRaws[i], DsStride[i]);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
...@@ -373,11 +368,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -373,11 +368,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to separate mean var descriptor for gemm and layernorm because of different grid // We have to separate mean var descriptor for gemm and layernorm because of different grid
// layout(different padding) // layout(different padding)
using GemmMeanVarGridDesc_M_NBlock = using GemmMeanVarGridDesc_M_NBlock = decltype(
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1)); MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
using GemmCountGridDesc_M_NBlock = using GemmCountGridDesc_M_NBlock = decltype(
decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1)); MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
using LayernormMeanVarGridDesc_M_NBlock = using LayernormMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>, decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,
...@@ -390,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -390,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
LayernormBlockTileSize_M_N::At(1)>(1, 1)); LayernormBlockTileSize_M_N::At(1)>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1)); using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<HLayout>(1, 1, 1)); using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<Sequence<true, true>, 1, 1>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
...@@ -412,9 +407,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -412,9 +407,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
GemmCountGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, GemmMPerBlock,
NPerBlock, GemmNPerBlock,
KPerBlock, GemmKPerBlock,
AK1, AK1,
BK1, BK1,
MPerXDL, MPerXDL,
...@@ -503,7 +498,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -503,7 +498,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)}, gemm_e_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
GemmMPerBlock,
GemmNPerBlock>(MRaw, NRaw, StrideH)},
layernorm_e_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, NRaw, StrideH)},
gemm_mean_var_grid_desc_m_nblock_{}, gemm_mean_var_grid_desc_m_nblock_{},
gemm_count_grid_desc_m_nblock_{}, gemm_count_grid_desc_m_nblock_{},
layernorm_mean_var_grid_desc_m_nblock_{}, layernorm_mean_var_grid_desc_m_nblock_{},
...@@ -512,12 +515,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -512,12 +515,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)}, DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
beta_grid_desc_n_{ beta_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)}, DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
block_2_etile_map_{GridwiseGemmWelford::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{
GridwiseGemmWelford::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
...@@ -525,16 +533,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -525,16 +533,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
MRaw_{MRaw}, MRaw_{MRaw},
NRaw_{NRaw}, NRaw_{NRaw},
KRaw_{KRaw}, KRaw_{KRaw},
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, GemmNPerBlock)},
epsilon_{static_cast<AccDataType>(epsilon)} epsilon_{static_cast<AccDataType>(epsilon)}
{ {
// We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1. // We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1.
gemm_mean_var_grid_desc_m_nblock_ = gemm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>( DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, 1>(
MRaw, gemm_nblock_); MRaw, gemm_nblock_);
gemm_count_grid_desc_m_nblock_ = gemm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>( DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, 1>(
MRaw, gemm_nblock_); MRaw, gemm_nblock_);
layernorm_mean_var_grid_desc_m_nblock_ = layernorm_mean_var_grid_desc_m_nblock_ =
...@@ -551,7 +559,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -551,7 +559,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer // D pointer
...@@ -559,14 +566,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -559,14 +566,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// D desc // D desc
ds_grid_desc_m_n_(i) = ds_grid_desc_m_n_(i) =
DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]); DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
GemmMPerBlock,
GemmNPerBlock>(MRaw, NRaw, StrideDs[i]);
}); });
// populate desc for Ds/E/mean/var/count // populate desc for Ds/E/mean/var/count
if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_, if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_, b_grid_desc_n_k_,
ds_grid_desc_m_n_, ds_grid_desc_m_n_,
e_grid_desc_m_n_, gemm_e_grid_desc_m_n_,
block_2_etile_map_)) block_2_etile_map_))
{ {
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
...@@ -575,7 +584,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -575,7 +584,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
e_grid_desc_mblock_mperblock_nblock_nperblock_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); gemm_e_grid_desc_m_n_);
gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = gemm_mean_var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
...@@ -593,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -593,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; std::cout << "E[M, N]: " << gemm_e_grid_desc_m_n_ << std::endl;
std::cout << "H[M, N]: " << h_grid_desc_m_n_ << std::endl; std::cout << "H[M, N]: " << h_grid_desc_m_n_ << std::endl;
} }
...@@ -614,7 +623,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -614,7 +623,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_; EHGridDesc_M_N gemm_e_grid_desc_m_n_;
EHGridDesc_M_N layernorm_e_grid_desc_m_n_;
GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_; GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_;
GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_; GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_;
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_; LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_;
...@@ -663,13 +673,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -663,13 +673,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.gemm_e_grid_desc_m_n_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
} }
index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_);
const auto M = arg.h_grid_desc_m_n_.GetLength(I0); const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
const auto N = arg.h_grid_desc_m_n_.GetLength(I1); const auto N = arg.h_grid_desc_m_n_.GetLength(I1);
...@@ -763,7 +773,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -763,7 +773,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
arg.p_gamma_grid_, arg.p_gamma_grid_,
arg.p_beta_grid_, arg.p_beta_grid_,
arg.p_h_grid_, arg.p_h_grid_,
arg.e_grid_desc_m_n_, arg.layernorm_e_grid_desc_m_n_,
arg.h_grid_desc_m_n_, arg.h_grid_desc_m_n_,
arg.layernorm_mean_var_grid_desc_m_nblock_, arg.layernorm_mean_var_grid_desc_m_nblock_,
arg.layernorm_count_grid_desc_m_nblock_, arg.layernorm_count_grid_desc_m_nblock_,
...@@ -1043,9 +1053,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -1043,9 +1053,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle" str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << GemmMPerBlock << ", "
<< NPerBlock << ", " << GemmNPerBlock << ", "
<< KPerBlock << ", " << GemmKPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment