Commit eaeef340 authored by rocking
Browse files

Add layout parameter

parent 27b19e34
...@@ -314,17 +314,37 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -314,17 +314,37 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
template <typename LayOut>
static auto MakeGemmMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock) static auto MakeGemmMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
{ {
const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock)); const auto grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, NBlock), make_tuple(NBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, NBlock), make_tuple(I1, M));
}
}();
return PadTensorDescriptor( return PadTensorDescriptor(
grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{}); grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});
} }
template <typename LayOut>
static auto MakeLayernormMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock) static auto MakeLayernormMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
{ {
const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock)); const auto grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, NBlock), make_tuple(NBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, LayOut>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, NBlock), make_tuple(I1, M));
}
}();
return PadTensorDescriptor( return PadTensorDescriptor(
grid_desc_m_n, grid_desc_m_n,
...@@ -388,9 +408,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -388,9 +408,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// We have to separate mean var descriptor for gemm and layernorm because of different grid // We have to separate mean var descriptor for gemm and layernorm because of different grid
// layout(different padding) // layout(different padding)
using GemmMeanVarCountGridDesc_M_NBlock = using GemmMeanVarCountGridDesc_M_NBlock =
decltype(MakeGemmMeanVarCountGridDescriptor_M_NBlock(1, 1)); decltype(MakeGemmMeanVarCountGridDescriptor_M_NBlock<HLayout>(1, 1));
using LayernormMeanVarCountGridDesc_M_NBlock = using LayernormMeanVarCountGridDesc_M_NBlock =
decltype(MakeLayernormMeanVarCountGridDescriptor_M_NBlock(1, 1)); decltype(MakeLayernormMeanVarCountGridDescriptor_M_NBlock<HLayout>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1)); using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1)); using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
...@@ -525,10 +545,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -525,10 +545,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
epsilon_{epsilon} epsilon_{epsilon}
{ {
gemm_mean_var_count_grid_desc_m_nblock_ = gemm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeGemmMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_); DeviceOp::MakeGemmMeanVarCountGridDescriptor_M_NBlock<HLayout>(MRaw, gemm_nblock_);
layernorm_mean_var_count_grid_desc_m_nblock_ = layernorm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeLayernormMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_); DeviceOp::MakeLayernormMeanVarCountGridDescriptor_M_NBlock<HLayout>(MRaw,
gemm_nblock_);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment