Commit 90546fbe authored by rocking's avatar rocking
Browse files

We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1

parent 119cb7b1
......@@ -523,12 +523,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{static_cast<AccDataType>(epsilon)}
{
// We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1.
gemm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(
MRaw, gemm_nblock_);
gemm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(
MRaw, gemm_nblock_);
layernorm_mean_var_grid_desc_m_nblock_ =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment