Commit c3107fd5 authored by rocking's avatar rocking
Browse files

Extract var

parent 5aa3c344
...@@ -340,22 +340,22 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -340,22 +340,22 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)}, e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
f_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>( f_grid_desc_m_n_{},
MRaw, g_grid_desc_m_n_{},
math::integer_divide_ceil(NRaw, NPerBlock),
math::integer_divide_ceil(NRaw, NPerBlock))},
g_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(
MRaw,
math::integer_divide_ceil(NRaw, NPerBlock),
math::integer_divide_ceil(NRaw, NPerBlock))},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
h_element_op_{h_element_op} h_element_op_{h_element_op},
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)}
{ {
int welford_size = MRaw * math::integer_divide_ceil(NRaw, NPerBlock); f_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
g_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
int welford_size = MRaw * blkGroupSize_;
hip_check_error(hipMalloc(&p_f_grid_, sizeof(FDataType) * welford_size)); hip_check_error(hipMalloc(&p_f_grid_, sizeof(FDataType) * welford_size));
hip_check_error(hipMalloc(&p_g_grid_, sizeof(GDataType) * welford_size)); hip_check_error(hipMalloc(&p_g_grid_, sizeof(GDataType) * welford_size));
...@@ -448,6 +448,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -448,6 +448,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
HElementwiseOperation h_element_op_; HElementwiseOperation h_element_op_;
int blkGroupSize_;
}; };
// Invoker // Invoker
......
...@@ -851,7 +851,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -851,7 +851,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
BlockwiseWelford<AccDataType, BlockwiseWelford<AccDataType,
BlockSize, BlockSize,
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
Sequence<0, 1>>; Sequence<0, 1>,
false>;
constexpr int num_shuffleM = constexpr int num_shuffleM =
MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl); MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment