Commit be14ab57 authored by rocking's avatar rocking
Browse files

padding for GemmMeanVarCountGridDescriptor_M_NBlock

parent 70e7069c
...@@ -318,9 +318,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -318,9 +318,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{ {
const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock)); const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
// TODO - padding according to MNperBlock of Gemm return PadTensorDescriptor(
// CAUSION - GetWorkSpaceSize grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});
return grid_desc_m_n;
} }
static auto MakeLayernormMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock) static auto MakeLayernormMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
...@@ -521,6 +520,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -521,6 +520,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
h_element_op_{h_element_op}, h_element_op_{h_element_op},
MRaw_(MRaw),
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon} epsilon_{epsilon}
{ {
...@@ -617,6 +617,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -617,6 +617,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
HElementwiseOperation h_element_op_; HElementwiseOperation h_element_op_;
int MRaw_;
int gemm_nblock_; int gemm_nblock_;
AccDataType epsilon_; AccDataType epsilon_;
}; };
...@@ -766,9 +767,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -766,9 +767,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
size_t workspace_size = 0; size_t workspace_size = 0;
// FIXME - padding int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
int gemm_welford_size =
pArg_->gemm_mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// workspace for welford intermediate mean // workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64; workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64;
...@@ -788,9 +787,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -788,9 +787,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
pArg_->p_workspace_ = p_workspace; pArg_->p_workspace_ = p_workspace;
int gemm_welford_size = int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
pArg_->gemm_mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// int gemm_welford_size = MRaw * pArg->gemm_nblock_;
// setup buffer used for intermediate welford mean // setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_); pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment