Commit be14ab57 authored by rocking
Browse files

padding for GemmMeanVarCountGridDescriptor_M_NBlock

parent 70e7069c
......@@ -318,9 +318,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{
const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
// TODO - padding according to MNperBlock of Gemm
// CAUTION - GetWorkSpaceSize
return grid_desc_m_n;
return PadTensorDescriptor(
grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});
}
static auto MakeLayernormMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
......@@ -521,6 +520,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
h_element_op_{h_element_op},
MRaw_(MRaw),
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
......@@ -617,6 +617,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CDEElementwiseOperation cde_element_op_;
HElementwiseOperation h_element_op_;
int MRaw_;
int gemm_nblock_;
AccDataType epsilon_;
};
......@@ -766,9 +767,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
size_t workspace_size = 0;
// FIXME - padding
int gemm_welford_size =
pArg_->gemm_mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
// workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64;
......@@ -788,9 +787,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
pArg_->p_workspace_ = p_workspace;
int gemm_welford_size =
pArg_->gemm_mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// int gemm_welford_size = MRaw * pArg->gemm_nblock_;
int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
// setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment