Commit c3107fd5 authored by rocking
Browse files

Extract var

parent 5aa3c344
......@@ -340,22 +340,22 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
f_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(
MRaw,
math::integer_divide_ceil(NRaw, NPerBlock),
math::integer_divide_ceil(NRaw, NPerBlock))},
g_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(
MRaw,
math::integer_divide_ceil(NRaw, NPerBlock),
math::integer_divide_ceil(NRaw, NPerBlock))},
f_grid_desc_m_n_{},
g_grid_desc_m_n_{},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
h_element_op_{h_element_op}
h_element_op_{h_element_op},
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)}
{
int welford_size = MRaw * math::integer_divide_ceil(NRaw, NPerBlock);
f_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
g_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
int welford_size = MRaw * blkGroupSize_;
hip_check_error(hipMalloc(&p_f_grid_, sizeof(FDataType) * welford_size));
hip_check_error(hipMalloc(&p_g_grid_, sizeof(GDataType) * welford_size));
......@@ -448,6 +448,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
HElementwiseOperation h_element_op_;
int blkGroupSize_;
};
// Invoker
......
......@@ -851,7 +851,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
BlockwiseWelford<AccDataType,
BlockSize,
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
Sequence<0, 1>>;
Sequence<0, 1>,
false>;
constexpr int num_shuffleM =
MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment