Commit 78ff5f81 authored by rocking
Browse files

Implement layernorm

parent a4e34d88
......@@ -111,8 +111,9 @@ template <typename GridwiseWelfordLayernorm,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
typename GammaBetaGridDesc_N>
typename MeanVarCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -127,13 +128,13 @@ __global__ void
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n,
const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n,
const MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n,
index_t blkgroup_size,
index_t num_mean_var_count_k_block_tile_iteration,
index_t num_xy_k_block_tile_iteration,
ComputeDataType epsilon)
index_t numMeanVarCountBlockTileIteration_N,
index_t numNormBlockTileIteration_N,
ComputeDataType epsilon,
HElementwiseOperation h_element_op)
{
GridwiseWelfordLayernorm::Run(p_e_grid,
p_in_welford_mean_grid,
......@@ -144,13 +145,13 @@ __global__ void
p_h_grid,
e_grid_desc_m_n,
h_grid_desc_m_n,
mean_var_count_grid_desc_m_n,
mean_var_count_grid_desc_m_nblock,
gamma_grid_desc_n,
beta_grid_desc_n,
blkgroup_size,
num_mean_var_count_k_block_tile_iteration,
num_xy_k_block_tile_iteration,
epsilon);
numMeanVarCountBlockTileIteration_N,
numNormBlockTileIteration_N,
epsilon,
h_element_op);
}
} // namespace ck
......@@ -374,7 +375,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using MeanVarCountGridDesc_M_N = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
using MeanVarCountGridDesc_M_NBlock = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
......@@ -394,7 +395,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K,
DsGridDesc_M_N,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -439,8 +440,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation,
BlockSize,
LayernormThreadClusterSize_M_N::At(I0),
LayernormThreadClusterSize_M_N::At(I1),
......@@ -488,7 +490,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_var_count_grid_desc_m_n_{},
mean_var_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
......@@ -504,7 +506,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
mean_var_count_grid_desc_m_n_ =
mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
......@@ -546,7 +548,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_var_count_grid_desc_m_n_);
mean_var_count_grid_desc_m_nblock_);
}
}
......@@ -578,7 +580,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_;
MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_;
MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock_;
GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_;
EHGridDesc_M_N h_grid_desc_m_n_;
......@@ -666,8 +668,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
GammaBetaGridDesc_N>;
MeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation>;
avg_time +=
launch_and_time_kernel(stream_config,
......@@ -692,17 +695,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
grid_size = math::integer_least_multiple(M, LayernormBlockTileSize_M_N::At(0)) /
LayernormBlockTileSize_M_N::At(0);
grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
index_t numMeanVarCountBlockTileIteration_N =
math::integer_least_multiple(arg.gemm_nblock_,
LayernormThreadClusterSize_M_N::At(I1)) /
LayernormThreadClusterSize_M_N::At(I1);
index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
index_t numEBlockTileIteration_N =
math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
LayernormBlockTileSize_M_N::At(I1);
index_t numNormBlockTileIteration_N =
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
......@@ -718,13 +717,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
arg.gemm_nblock_,
numMeanVarCountBlockTileIteration_N,
numEBlockTileIteration_N,
arg.epsilon_);
numNormBlockTileIteration_N,
arg.epsilon_,
arg.h_element_op_);
return avg_time;
};
......
......@@ -47,7 +47,7 @@ template <typename ABDataType,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -349,7 +349,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>;
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment