Commit 78ff5f81 authored by rocking's avatar rocking
Browse files

Implement layernorm

parent a4e34d88
...@@ -111,8 +111,9 @@ template <typename GridwiseWelfordLayernorm, ...@@ -111,8 +111,9 @@ template <typename GridwiseWelfordLayernorm,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename EHGridDesc_M_N, typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_N, typename MeanVarCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N> typename GammaBetaGridDesc_N,
typename HElementwiseOperation>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -127,13 +128,13 @@ __global__ void ...@@ -127,13 +128,13 @@ __global__ void
HDataType* __restrict__ p_h_grid, HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n, const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n, const EHGridDesc_M_N h_grid_desc_m_n,
const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n, const MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n, const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n, const GammaBetaGridDesc_N beta_grid_desc_n,
index_t blkgroup_size, index_t numMeanVarCountBlockTileIteration_N,
index_t num_mean_var_count_k_block_tile_iteration, index_t numNormBlockTileIteration_N,
index_t num_xy_k_block_tile_iteration, ComputeDataType epsilon,
ComputeDataType epsilon) HElementwiseOperation h_element_op)
{ {
GridwiseWelfordLayernorm::Run(p_e_grid, GridwiseWelfordLayernorm::Run(p_e_grid,
p_in_welford_mean_grid, p_in_welford_mean_grid,
...@@ -144,13 +145,13 @@ __global__ void ...@@ -144,13 +145,13 @@ __global__ void
p_h_grid, p_h_grid,
e_grid_desc_m_n, e_grid_desc_m_n,
h_grid_desc_m_n, h_grid_desc_m_n,
mean_var_count_grid_desc_m_n, mean_var_count_grid_desc_m_nblock,
gamma_grid_desc_n, gamma_grid_desc_n,
beta_grid_desc_n, beta_grid_desc_n,
blkgroup_size, numMeanVarCountBlockTileIteration_N,
num_mean_var_count_k_block_tile_iteration, numNormBlockTileIteration_N,
num_xy_k_block_tile_iteration, epsilon,
epsilon); h_element_op);
} }
} // namespace ck } // namespace ck
...@@ -371,12 +372,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -371,12 +372,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
} }
}; };
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using MeanVarCountGridDesc_M_N = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1)); using MeanVarCountGridDesc_M_NBlock = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1)); using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1)); using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
...@@ -394,7 +395,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -394,7 +395,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K, BGridDesc_N_K,
DsGridDesc_M_N, DsGridDesc_M_N,
EHGridDesc_M_N, EHGridDesc_M_N,
MeanVarCountGridDesc_M_N, MeanVarCountGridDesc_M_NBlock,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -439,8 +440,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -439,8 +440,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType, BetaDataType,
AccDataType, AccDataType,
EHGridDesc_M_N, EHGridDesc_M_N,
MeanVarCountGridDesc_M_N, MeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N, GammaBetaGridDesc_N,
HElementwiseOperation,
BlockSize, BlockSize,
LayernormThreadClusterSize_M_N::At(I0), LayernormThreadClusterSize_M_N::At(I0),
LayernormThreadClusterSize_M_N::At(I1), LayernormThreadClusterSize_M_N::At(I1),
...@@ -488,7 +490,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -488,7 +490,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_var_count_grid_desc_m_n_{}, mean_var_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
...@@ -504,7 +506,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -504,7 +506,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon} epsilon_{epsilon}
{ {
mean_var_count_grid_desc_m_n_ = mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_); DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw)); hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
...@@ -546,7 +548,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -546,7 +548,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_mblock_mperblock_nblock_ = mean_var_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_var_count_grid_desc_m_n_); mean_var_count_grid_desc_m_nblock_);
} }
} }
...@@ -578,7 +580,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -578,7 +580,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_; EHGridDesc_M_N e_grid_desc_m_n_;
MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_; MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock_;
GammaBetaGridDesc_N gamma_grid_desc_n_; GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_; GammaBetaGridDesc_N beta_grid_desc_n_;
EHGridDesc_M_N h_grid_desc_m_n_; EHGridDesc_M_N h_grid_desc_m_n_;
...@@ -666,8 +668,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -666,8 +668,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType, BetaDataType,
AccDataType, AccDataType,
EHGridDesc_M_N, EHGridDesc_M_N,
MeanVarCountGridDesc_M_N, MeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N>; GammaBetaGridDesc_N,
HElementwiseOperation>;
avg_time += avg_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -692,17 +695,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -692,17 +695,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_, arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
grid_size = math::integer_least_multiple(M, LayernormBlockTileSize_M_N::At(0)) / grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
LayernormBlockTileSize_M_N::At(0);
index_t numMeanVarCountBlockTileIteration_N = index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
math::integer_least_multiple(arg.gemm_nblock_, arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
LayernormThreadClusterSize_M_N::At(I1)) /
LayernormThreadClusterSize_M_N::At(I1);
index_t numEBlockTileIteration_N = index_t numNormBlockTileIteration_N =
math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) / math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));
LayernormBlockTileSize_M_N::At(I1);
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm, kernel_welford_layernorm,
...@@ -718,13 +717,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -718,13 +717,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_h_grid_, arg.p_h_grid_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_, arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_n_, arg.mean_var_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_, arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_, arg.beta_grid_desc_n_,
arg.gemm_nblock_,
numMeanVarCountBlockTileIteration_N, numMeanVarCountBlockTileIteration_N,
numEBlockTileIteration_N, numNormBlockTileIteration_N,
arg.epsilon_); arg.epsilon_,
arg.h_element_op_);
return avg_time; return avg_time;
}; };
......
...@@ -47,7 +47,7 @@ template <typename ABDataType, ...@@ -47,7 +47,7 @@ template <typename ABDataType,
typename BGridDesc_N_K, typename BGridDesc_N_K,
typename DsGridDesc_M_N, typename DsGridDesc_M_N,
typename EGridDesc_M_N, typename EGridDesc_M_N,
typename MeanVarCountGridDesc_M_N, typename MeanVarCountGridDesc_M_NBlock,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -349,7 +349,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -349,7 +349,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype( using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>; MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment