Commit 90546fbe authored by rocking's avatar rocking
Browse files

We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1

parent 119cb7b1
...@@ -523,12 +523,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -523,12 +523,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{static_cast<AccDataType>(epsilon)} epsilon_{static_cast<AccDataType>(epsilon)}
{ {
// We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1.
gemm_mean_var_grid_desc_m_nblock_ = gemm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>( DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(
MRaw, gemm_nblock_); MRaw, gemm_nblock_);
gemm_count_grid_desc_m_nblock_ = gemm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>( DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(
MRaw, gemm_nblock_); MRaw, gemm_nblock_);
layernorm_mean_var_grid_desc_m_nblock_ = layernorm_mean_var_grid_desc_m_nblock_ =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment