"vscode:/vscode.git/clone" did not exist on "0747f76a86b63baf8ff15de8556d1f832df1b2cd"
Commit 39dedce7 authored by rocking's avatar rocking
Browse files

[What] Rename MakeMeanVarDescriptor_M_N

[Why] Prepare to add count version of make descriptor
parent 48c1b923
...@@ -286,7 +286,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -286,7 +286,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
} }
template <typename LayOut> template <typename LayOut>
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride) static auto MakeEHGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
{ {
const auto grid_desc_mraw_nraw = [&]() { const auto grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
...@@ -310,26 +310,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -310,26 +310,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]); return DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
template <typename LayOut, typename DoPads, index_t XPerTile, index_t YPerTile> template <typename LayOut, typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeDescriptor_X_Y(index_t X, index_t Y) static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
{ {
const auto grid_desc_m_n = [&]() { const auto grid_desc_m_n =
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value) make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
{ return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
return make_naive_tensor_descriptor(make_tuple(X, Y), make_tuple(Y, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(X, Y), make_tuple(I1, X));
}
}();
return PadTensorDescriptor(grid_desc_m_n, make_tuple(XPerTile, YPerTile), DoPads{});
} }
template <index_t XPerTile> template <index_t XPerTile>
...@@ -344,17 +335,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -344,17 +335,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to separate mean var descriptor for gemm and layernorm bacause of different grid // We have to separate mean var descriptor for gemm and layernorm bacause of different grid
// layout(different padding) // layout(different padding)
using GemmMeanVarCountGridDesc_M_NBlock = using GemmMeanVarCountGridDesc_M_NBlock = decltype(
decltype(MakeDescriptor_X_Y<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(1, 1)); MakeMeanVarDescriptor_M_N<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using LayernormMeanVarCountGridDesc_M_NBlock = using LayernormMeanVarCountGridDesc_M_NBlock =
decltype(MakeDescriptor_X_Y<HLayout, decltype(MakeMeanVarDescriptor_M_N<HLayout,
Sequence<true, true>, Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1)); LayernormBlockTileSize_M_N::At(1)>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1)); using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1)); using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
...@@ -464,14 +455,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -464,14 +455,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)}, e_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
gemm_mean_var_count_grid_desc_m_nblock_{}, gemm_mean_var_count_grid_desc_m_nblock_{},
layernorm_mean_var_count_grid_desc_m_nblock_{}, layernorm_mean_var_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{ gamma_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)}, DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
beta_grid_desc_n_{ beta_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)}, DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
...@@ -487,15 +478,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -487,15 +478,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon} epsilon_{epsilon}
{ {
gemm_mean_var_count_grid_desc_m_nblock_ = gemm_mean_var_count_grid_desc_m_nblock_ = DeviceOp::
DeviceOp::MakeDescriptor_X_Y<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>( MakeMeanVarDescriptor_M_N<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(
MRaw, gemm_nblock_); MRaw, gemm_nblock_);
layernorm_mean_var_count_grid_desc_m_nblock_ = layernorm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeDescriptor_X_Y<HLayout, DeviceOp::MakeMeanVarDescriptor_M_N<HLayout,
Sequence<true, true>, Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_); LayernormBlockTileSize_M_N::At(1)>(
MRaw, gemm_nblock_);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -507,7 +499,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -507,7 +499,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// D desc // D desc
ds_grid_desc_m_n_(i) = ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]); DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
}); });
// populate desc for Ds/E/F/G // populate desc for Ds/E/F/G
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment