Commit 39dedce7 authored by rocking's avatar rocking
Browse files

[What] Rename MakeMeanVarDescriptor_M_N

[Why] Prepare to add count version of make descriptor
parent 48c1b923
......@@ -286,7 +286,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
template <typename LayOut>
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
static auto MakeEHGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
{
const auto grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
......@@ -310,26 +310,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
return DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
template <typename LayOut, typename DoPads, index_t XPerTile, index_t YPerTile>
static auto MakeDescriptor_X_Y(index_t X, index_t Y)
template <typename LayOut, typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
{
const auto grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(X, Y), make_tuple(Y, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(X, Y), make_tuple(I1, X));
}
}();
return PadTensorDescriptor(grid_desc_m_n, make_tuple(XPerTile, YPerTile), DoPads{});
const auto grid_desc_m_n =
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
}
template <index_t XPerTile>
......@@ -344,17 +335,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to separate mean var descriptor for gemm and layernorm bacause of different grid
// layout(different padding)
using GemmMeanVarCountGridDesc_M_NBlock =
decltype(MakeDescriptor_X_Y<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using GemmMeanVarCountGridDesc_M_NBlock = decltype(
MakeMeanVarDescriptor_M_N<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using LayernormMeanVarCountGridDesc_M_NBlock =
decltype(MakeDescriptor_X_Y<HLayout,
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
decltype(MakeMeanVarDescriptor_M_N<HLayout,
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
......@@ -464,14 +455,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
e_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
gemm_mean_var_count_grid_desc_m_nblock_{},
layernorm_mean_var_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
beta_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
h_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
......@@ -487,15 +478,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
gemm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeDescriptor_X_Y<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(
gemm_mean_var_count_grid_desc_m_nblock_ = DeviceOp::
MakeMeanVarDescriptor_M_N<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(
MRaw, gemm_nblock_);
layernorm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeDescriptor_X_Y<HLayout,
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);
DeviceOp::MakeMeanVarDescriptor_M_N<HLayout,
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, gemm_nblock_);
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
......@@ -507,7 +499,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E/F/G
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment