Commit 14d29856 authored by rocking's avatar rocking
Browse files

Pad different size for E and H in layernorm kernel according to different block tile

parent d78877a7
......@@ -202,9 +202,9 @@ template <typename ALayout,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
......@@ -249,8 +249,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
CDEElementwiseOperation,
HElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using ELayout = HLayout;
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using ELayout = HLayout;
// EDataType, MeanDataType and VarDataType must be the same.
// eg. M, N, K = [1, 1, 1],
// in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783
......@@ -274,8 +274,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static constexpr auto matrix_padder = MatrixPadder<GemmSpec, index_t, index_t, index_t>{
GemmMPerBlock, GemmNPerBlock, GemmKPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
......@@ -313,21 +313,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename LayOut>
static auto MakeEHGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride)
{
const auto grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(Stride, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(I1, Stride));
}
}();
return matrix_padder.PadCDescriptor_M_N(grid_desc_mraw_nraw);
// Only support row major for E and H
const auto grid_desc_m_n =
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(Stride, I1));
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
......@@ -337,8 +329,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
static_assert(is_same<tensor_layout::gemm::RowMajor, DLayout>::value);
return DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
return DeviceOp::
MakeEHGridDescriptor_M_N<Sequence<true, true>, GemmMPerBlock, GemmNPerBlock>(
MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
......@@ -373,11 +368,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to separate mean var descriptor for gemm and layernorm bacause of different grid
// layout(different padding)
using GemmMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using GemmMeanVarGridDesc_M_NBlock = decltype(
MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
using GemmCountGridDesc_M_NBlock =
decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using GemmCountGridDesc_M_NBlock = decltype(
MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
using LayernormMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,
......@@ -390,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<HLayout>(1, 1, 1));
using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<Sequence<true, true>, 1, 1>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
......@@ -412,9 +407,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
GemmCountGridDesc_M_NBlock,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
AK1,
BK1,
MPerXDL,
......@@ -503,7 +498,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
gemm_e_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
GemmMPerBlock,
GemmNPerBlock>(MRaw, NRaw, StrideH)},
layernorm_e_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, NRaw, StrideH)},
gemm_mean_var_grid_desc_m_nblock_{},
gemm_count_grid_desc_m_nblock_{},
layernorm_mean_var_grid_desc_m_nblock_{},
......@@ -512,12 +515,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
beta_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
h_grid_desc_m_n_{
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
block_2_etile_map_{GridwiseGemmWelford::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
block_2_etile_map_{
GridwiseGemmWelford::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
......@@ -525,16 +533,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw},
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
gemm_nblock_{math::integer_divide_ceil(NRaw, GemmNPerBlock)},
epsilon_{static_cast<AccDataType>(epsilon)}
{
// We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1.
gemm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, 1>(
MRaw, gemm_nblock_);
gemm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, 1>(
MRaw, gemm_nblock_);
layernorm_mean_var_grid_desc_m_nblock_ =
......@@ -551,7 +559,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
......@@ -559,14 +566,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
GemmMPerBlock,
GemmNPerBlock>(MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E/mean/var/count
if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
gemm_e_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
......@@ -575,7 +584,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
gemm_e_grid_desc_m_n_);
gemm_mean_var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
......@@ -593,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
std::cout << "E[M, N]: " << gemm_e_grid_desc_m_n_ << std::endl;
std::cout << "H[M, N]: " << h_grid_desc_m_n_ << std::endl;
}
......@@ -614,7 +623,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_;
EHGridDesc_M_N gemm_e_grid_desc_m_n_;
EHGridDesc_M_N layernorm_e_grid_desc_m_n_;
GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_;
GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_;
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_;
......@@ -663,13 +673,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.gemm_e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
}
index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_);
const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
const auto N = arg.h_grid_desc_m_n_.GetLength(I1);
......@@ -763,7 +773,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.layernorm_e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.layernorm_mean_var_grid_desc_m_nblock_,
arg.layernorm_count_grid_desc_m_nblock_,
......@@ -1043,9 +1053,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< GemmMPerBlock << ", "
<< GemmNPerBlock << ", "
<< GemmKPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getGemmSpecializationString(GemmSpec)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment