Commit 5e215d49 authored by rocking's avatar rocking
Browse files

Refine the MakeDescriptor

parent 328cc7f4
......@@ -315,104 +315,45 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number<NumDTensor>{});
}
template <typename LayOut>
static auto MakeGemmMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
template <typename LayOut, typename DoPads, index_t XPerTile, index_t YPerTile>
static auto MakeDescriptor_X_Y(index_t X, index_t Y)
{
const auto grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, NBlock), make_tuple(NBlock, I1));
return make_naive_tensor_descriptor(make_tuple(X, Y), make_tuple(Y, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, NBlock), make_tuple(I1, M));
return make_naive_tensor_descriptor(make_tuple(X, Y), make_tuple(I1, X));
}
}();
return PadTensorDescriptor(
grid_desc_m_n, make_tuple(MPerBlock, NBlock), Sequence<true, false>{});
return PadTensorDescriptor(grid_desc_m_n, make_tuple(XPerTile, YPerTile), DoPads{});
}
template <typename LayOut>
static auto MakeLayernormMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
template <index_t XPerTile>
static auto MakeDescriptor_X(index_t X)
{
const auto grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, NBlock), make_tuple(NBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, NBlock), make_tuple(I1, M));
}
}();
return PadTensorDescriptor(
grid_desc_m_n,
make_tuple(LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)),
Sequence<true, true>{});
const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X));
return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence<true>{});
}
// Creates the 1-D grid descriptor for a tensor indexed along M.
//
// MRaw is the unpadded length of the M dimension. When GemmSpec is one of the
// specializations that pad M (MPadding, MNPadding, MKPadding, MNKPadding) the
// packed descriptor is right-padded up to the next multiple of MPerBlock;
// for every other specialization the packed descriptor is returned unchanged.
static auto MakeDescriptor_M(index_t MRaw)
{
    // Compile-time switch: does the selected GEMM specialization pad M?
    constexpr bool pad_m = GemmSpec == GemmSpecialization::MPadding ||
                           GemmSpec == GemmSpecialization::MNPadding ||
                           GemmSpec == GemmSpecialization::MKPadding ||
                           GemmSpec == GemmSpecialization::MNKPadding;

    const auto desc_raw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

    if constexpr(pad_m)
    {
        // Round MRaw up to a whole number of MPerBlock-sized tiles and pad the
        // difference on the right.
        const auto m_padded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto m_pad    = m_padded - MRaw;

        return transform_tensor_descriptor(desc_raw,
                                           make_tuple(make_right_pad_transform(MRaw, m_pad)),
                                           make_tuple(Sequence<0>{}),
                                           make_tuple(Sequence<0>{}));
    }
    else
    {
        // M is not padded for this specialization.
        return desc_raw;
    }
}
// Creates the 1-D grid descriptor for a tensor indexed along N.
//
// NRaw is the unpadded length of the N dimension. When GemmSpec is one of the
// specializations that pad N (NPadding, MNPadding, NKPadding, MNKPadding) the
// packed descriptor is right-padded up to the next multiple of NPerBlock;
// for every other specialization the packed descriptor is returned unchanged.
static auto MakeDescriptor_N(index_t NRaw)
{
    // Compile-time switch: does the selected GEMM specialization pad N?
    constexpr bool pad_n = GemmSpec == GemmSpecialization::NPadding ||
                           GemmSpec == GemmSpecialization::MNPadding ||
                           GemmSpec == GemmSpecialization::NKPadding ||
                           GemmSpec == GemmSpecialization::MNKPadding;

    const auto desc_raw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));

    if constexpr(pad_n)
    {
        // Round NRaw up to a whole number of NPerBlock-sized tiles and pad the
        // difference on the right.
        const auto n_padded = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
        const auto n_pad    = n_padded - NRaw;

        return transform_tensor_descriptor(desc_raw,
                                           make_tuple(make_right_pad_transform(NRaw, n_pad)),
                                           make_tuple(Sequence<0>{}),
                                           make_tuple(Sequence<0>{}));
    }
    else
    {
        // N is not padded for this specialization.
        return desc_raw;
    }
}
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to keep separate mean/var descriptors for gemm and layernorm because of their
// different grid layouts (different padding)
using GemmMeanVarCountGridDesc_M_NBlock =
decltype(MakeGemmMeanVarCountGridDescriptor_M_NBlock<HLayout>(1, 1));
decltype(MakeDescriptor_X_Y<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using LayernormMeanVarCountGridDesc_M_NBlock =
decltype(MakeLayernormMeanVarCountGridDescriptor_M_NBlock<HLayout>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
decltype(MakeDescriptor_X_Y<HLayout,
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
......@@ -526,8 +467,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
gemm_mean_var_count_grid_desc_m_nblock_{},
layernorm_mean_var_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
gamma_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
beta_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
......@@ -545,11 +488,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
epsilon_{epsilon}
{
gemm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeGemmMeanVarCountGridDescriptor_M_NBlock<HLayout>(MRaw, gemm_nblock_);
DeviceOp::MakeDescriptor_X_Y<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(
MRaw, gemm_nblock_);
layernorm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeLayernormMeanVarCountGridDescriptor_M_NBlock<HLayout>(MRaw,
gemm_nblock_);
DeviceOp::MakeDescriptor_X_Y<HLayout,
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment