Commit 78ff5f81 authored by rocking

Implement layernorm

parent a4e34d88
@@ -111,8 +111,9 @@ template <typename GridwiseWelfordLayernorm,
           typename BetaDataType,
           typename ComputeDataType,
           typename EHGridDesc_M_N,
-          typename MeanVarCountGridDesc_M_N,
-          typename GammaBetaGridDesc_N>
+          typename MeanVarCountGridDesc_M_NBlock,
+          typename GammaBetaGridDesc_N,
+          typename HElementwiseOperation>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -127,13 +128,13 @@ __global__ void
         HDataType* __restrict__ p_h_grid,
         const EHGridDesc_M_N e_grid_desc_m_n,
         const EHGridDesc_M_N h_grid_desc_m_n,
-        const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n,
+        const MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
         const GammaBetaGridDesc_N gamma_grid_desc_n,
         const GammaBetaGridDesc_N beta_grid_desc_n,
-        index_t blkgroup_size,
-        index_t num_mean_var_count_k_block_tile_iteration,
-        index_t num_xy_k_block_tile_iteration,
-        ComputeDataType epsilon)
+        index_t numMeanVarCountBlockTileIteration_N,
+        index_t numNormBlockTileIteration_N,
+        ComputeDataType epsilon,
+        HElementwiseOperation h_element_op)
 {
     GridwiseWelfordLayernorm::Run(p_e_grid,
                                   p_in_welford_mean_grid,
@@ -144,13 +145,13 @@ __global__ void
                                   p_h_grid,
                                   e_grid_desc_m_n,
                                   h_grid_desc_m_n,
-                                  mean_var_count_grid_desc_m_n,
+                                  mean_var_count_grid_desc_m_nblock,
                                   gamma_grid_desc_n,
                                   beta_grid_desc_n,
-                                  blkgroup_size,
-                                  num_mean_var_count_k_block_tile_iteration,
-                                  num_xy_k_block_tile_iteration,
-                                  epsilon);
+                                  numMeanVarCountBlockTileIteration_N,
+                                  numNormBlockTileIteration_N,
+                                  epsilon,
+                                  h_element_op);
 }
 } // namespace ck
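The hunk above only changes the wrapper's signature; the pattern itself stays the same. As a minimal sketch (illustrative names, not library code): a thin __global__ entry point forwards every argument to a gridwise struct's static Run(), so the grid-level logic remains an ordinary, reusable template class:

    // Minimal sketch of the wrapper pattern used above (illustrative names):
    // the __global__ entry point does nothing but forward to Gridwise::Run.
    template <typename Gridwise, typename GridDesc, typename ElementwiseOp>
    __global__ void kernel_forwarding_wrapper(const float* __restrict__ p_in,
                                              float* __restrict__ p_out,
                                              const GridDesc grid_desc,
                                              const ElementwiseOp op)
    {
        Gridwise::Run(p_in, p_out, grid_desc, op);
    }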
@@ -374,7 +375,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
     using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
     using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
-    using MeanVarCountGridDesc_M_N = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
+    using MeanVarCountGridDesc_M_NBlock = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
     using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
     using EHGridDesc_M_N      = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
@@ -394,7 +395,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         BGridDesc_N_K,
         DsGridDesc_M_N,
         EHGridDesc_M_N,
-        MeanVarCountGridDesc_M_N,
+        MeanVarCountGridDesc_M_NBlock,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
@@ -439,8 +440,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         BetaDataType,
         AccDataType,
         EHGridDesc_M_N,
-        MeanVarCountGridDesc_M_N,
+        MeanVarCountGridDesc_M_NBlock,
         GammaBetaGridDesc_N,
+        HElementwiseOperation,
         BlockSize,
         LayernormThreadClusterSize_M_N::At(I0),
         LayernormThreadClusterSize_M_N::At(I1),
@@ -488,7 +490,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
               ds_grid_desc_m_n_{},
               e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
-              mean_var_count_grid_desc_m_n_{},
+              mean_var_count_grid_desc_m_nblock_{},
               gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
               beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
               h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
@@ -504,7 +506,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
               epsilon_{epsilon}
         {
-            mean_var_count_grid_desc_m_n_ =
+            mean_var_count_grid_desc_m_nblock_ =
                 DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
             hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
@@ -546,7 +548,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                 mean_var_count_grid_desc_mblock_mperblock_nblock_ =
                     GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
-                        mean_var_count_grid_desc_m_n_);
+                        mean_var_count_grid_desc_m_nblock_);
             }
         }
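For orientation, a sketch under assumptions (the struct and helper below are hypothetical, not CK types): the _M_NBlock suffix adopted throughout this commit reflects that the first-half GEMM kernel writes one Welford partial per (row, GEMM N-block) pair, so the mean/var/count workspace is an [MRaw, gemm_nblock_] grid that the second-half kernel later reduces across its NBlock dimension:

    // Hypothetical illustration of the workspace shape the rename documents.
    struct WelfordPartial
    {
        float mean;  // partial mean of the ~NPerBlock elements this block saw
        float var;   // partial (biased) variance of those same elements
        int   count; // how many elements went into this partial
    };

    // Packed row-major [M, NBlock] indexing: one partial per (row, N-block).
    inline int partial_offset(int m, int nblock, int num_nblocks)
    {
        return m * num_nblocks + nblock;
    }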
@@ -578,7 +580,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         BGridDesc_N_K b_grid_desc_n_k_;
         DsGridDesc_M_N ds_grid_desc_m_n_;
         EHGridDesc_M_N e_grid_desc_m_n_;
-        MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_;
+        MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock_;
         GammaBetaGridDesc_N gamma_grid_desc_n_;
         GammaBetaGridDesc_N beta_grid_desc_n_;
         EHGridDesc_M_N h_grid_desc_m_n_;
@@ -666,8 +668,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                 BetaDataType,
                 AccDataType,
                 EHGridDesc_M_N,
-                MeanVarCountGridDesc_M_N,
-                GammaBetaGridDesc_N>;
+                MeanVarCountGridDesc_M_NBlock,
+                GammaBetaGridDesc_N,
+                HElementwiseOperation>;
             avg_time +=
                 launch_and_time_kernel(stream_config,
@@ -692,17 +695,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                        arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
                                        arg.block_2_etile_map_);

-            grid_size = math::integer_least_multiple(M, LayernormBlockTileSize_M_N::At(0)) /
-                        LayernormBlockTileSize_M_N::At(0);
+            grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));

-            index_t numMeanVarCountBlockTileIteration_N =
-                math::integer_least_multiple(arg.gemm_nblock_,
-                                             LayernormThreadClusterSize_M_N::At(I1)) /
-                LayernormThreadClusterSize_M_N::At(I1);
+            index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
+                arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));

-            index_t numEBlockTileIteration_N =
-                math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
-                LayernormBlockTileSize_M_N::At(I1);
+            index_t numNormBlockTileIteration_N =
+                math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));

             avg_time += launch_and_time_kernel(stream_config,
                                                kernel_welford_layernorm,
@@ -718,13 +717,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                                arg.p_h_grid_,
                                                arg.e_grid_desc_m_n_,
                                                arg.h_grid_desc_m_n_,
-                                               arg.mean_var_count_grid_desc_m_n_,
+                                               arg.mean_var_count_grid_desc_m_nblock_,
                                                arg.gamma_grid_desc_n_,
                                                arg.beta_grid_desc_n_,
-                                               arg.gemm_nblock_,
                                                numMeanVarCountBlockTileIteration_N,
-                                               numEBlockTileIteration_N,
-                                               arg.epsilon_);
+                                               numNormBlockTileIteration_N,
+                                               arg.epsilon_,
+                                               arg.h_element_op_);

             return avg_time;
         };
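The grid-size and iteration-count computations above were simplified: for positive tile sizes, math::integer_divide_ceil(x, T) equals the old math::integer_least_multiple(x, T) / T. A self-contained check of that equivalence (the two helpers are re-implemented here for illustration; the real ones live in ck::math):

    // Round-up division versus round-up-to-multiple-then-divide: same result.
    constexpr int integer_divide_ceil(int x, int t) { return (x + t - 1) / t; }
    constexpr int integer_least_multiple(int x, int t)
    {
        return integer_divide_ceil(x, t) * t;
    }

    static_assert(integer_divide_ceil(100, 32) == 4, "");
    static_assert(integer_least_multiple(100, 32) / 32 == integer_divide_ceil(100, 32), "");
    static_assert(integer_least_multiple(96, 32) / 32 == integer_divide_ceil(96, 32), "");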
...
@@ -47,7 +47,7 @@ template <typename ABDataType,
           typename BGridDesc_N_K,
           typename DsGridDesc_M_N,
           typename EGridDesc_M_N,
-          typename MeanVarCountGridDesc_M_N,
+          typename MeanVarCountGridDesc_M_NBlock,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -349,7 +349,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
     using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
     using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
-        MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>;
+        MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_NBlock{}))>;
     using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
...
@@ -27,8 +27,9 @@ template <typename EDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename EHGridDesc_M_N,
-          typename MeanVarCountGridDesc_M_N,
+          typename MeanVarCountGridDesc_M_NBlock,
           typename GammaBetaGridDesc_N,
+          typename HElementwiseOperation,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t NThreadClusterSize,
@@ -42,32 +43,34 @@ template <typename EDataType,
           index_t MeanVarSrcDstVectorSize>
 struct GridwiseWelfordSecondHalfLayernorm2d
 {
-    static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % ESrcVectorSize == 0) ||
-                      (ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0),
+    // TODO - Support ESrcHDstVectorDim == 0
+    static_assert(ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0 &&
+                      NThreadSliceSize % GammaSrcVectorSize == 0 &&
+                      NThreadSliceSize % BetaSrcVectorSize == 0,
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");

-    static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % HDstVectorSize == 0) ||
-                      (ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0),
+    static_assert(ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0,
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");

-    static constexpr bool reorder_thread_cluster = (ESrcHDstVectorDim == 0);
-
     using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
+    using ThreadBufferDimAccessOrder = Sequence<0, 1>;
+    using ThreadClusterArrangeOrder  = Sequence<0, 1>;

-    using ThreadBufferDimAccessOrder =
-        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
-    using ThreadClusterArrangeOrder =
-        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
-
-    static constexpr auto thread_cluster_desc =
+    static constexpr auto thread_cluster_desc_m_n =
         make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});

-    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
+    using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
+    static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
+
+    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
     static constexpr auto thread_buffer_desc_m_1 =
         make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

+    using ThreadBufferLengths_N = Sequence<NThreadSliceSize>;
+    static constexpr auto thread_buffer_desc_n =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<NThreadSliceSize>{}));
+
     using ThreadReduceSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
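The new thread_buffer_desc_m_n and thread_buffer_desc_n descriptors are packed (contiguous, row-major) views of a thread's register tile: a packed [M, N] descriptor maps (m, n) to m * N + n, which is what CalculateOffset returns inside the normalization loop later in this file. A stand-in sketch (PackedDesc2D is illustrative, not the CK type):

    // Stand-in for a packed 2-D naive tensor descriptor: row-major offsets.
    template <int M, int N>
    struct PackedDesc2D
    {
        static constexpr int CalculateOffset(int m, int n) { return m * N + n; }
    };

    static_assert(PackedDesc2D<4, 8>::CalculateOffset(0, 0) == 0, "");
    static_assert(PackedDesc2D<4, 8>::CalculateOffset(2, 3) == 19, "");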
@@ -80,19 +83,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                             ThreadClusterLengths_M_N,
                                             ThreadClusterArrangeOrder>;

-    using PassThroughOp = tensor_operation::element_wise::PassThrough;
-
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
-    static constexpr index_t N_BlockTileStepSize = NThreadClusterSize * ESrcVectorSize;
-
-    static constexpr auto EThreadBufferNumber     = Number<NThreadSliceSize / ESrcVectorSize>{};
-    static constexpr auto GammaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
-    static constexpr auto BetaThreadBufferNumber  = Number<NThreadSliceSize / ESrcVectorSize>{};
-    static constexpr auto HThreadBufferNumber     = Number<NThreadSliceSize / ESrcVectorSize>{};

     __device__ static void Run(const EDataType* __restrict__ p_e_grid,
                                const MeanDataType* __restrict__ p_in_welford_mean_grid,
@@ -103,47 +98,88 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                HDataType* __restrict__ p_h_grid,
                                const EHGridDesc_M_N& e_grid_desc_m_n,
                                const EHGridDesc_M_N& h_grid_desc_m_n,
-                               const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
+                               const MeanVarCountGridDesc_M_NBlock& mean_var_count_grid_desc_m_n,
                                const GammaBetaGridDesc_N& gamma_grid_desc_n,
                                const GammaBetaGridDesc_N& beta_grid_desc_n,
-                               index_t gemm_nblock_,
                                index_t numMeanVarCountBlockTileIteration_N,
-                               index_t numEBlockTileIteration_N,
-                               ComputeDataType epsilon)
+                               index_t numNormBlockTileIteration_N,
+                               ComputeDataType epsilon,
+                               HElementwiseOperation h_element_op)
     {
-        ignore = p_e_grid;
-        ignore = p_in_welford_mean_grid;
-        ignore = p_in_welford_var_grid;
-        ignore = p_in_welford_count_grid;
-        ignore = p_gamma_grid;
-        ignore = p_beta_grid;
-        ignore = p_h_grid;
-        ignore = e_grid_desc_m_n;
-        ignore = h_grid_desc_m_n;
-        ignore = mean_var_count_grid_desc_m_n;
-        ignore = gamma_grid_desc_n;
-        ignore = beta_grid_desc_n;
-        ignore = gemm_nblock_;
-        ignore = numMeanVarCountBlockTileIteration_N;
-        ignore = numEBlockTileIteration_N;
-        ignore = epsilon;
-
         // Thread/Block id
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();

         const auto thread_cluster_idx =
-            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+            thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_n_cluster_id = thread_cluster_idx[I1];

-        // step1: Merge mean and variance
-        auto threadwise_mean_load_m_k =
+        // Global Memory
+        const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
+
+        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
+
+        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());
+
+        auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());
+
+        // VGPR
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            in_welford_mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            in_welford_var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+            in_welford_count_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            welford_mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            welford_var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+            welford_count_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     ComputeDataType,
+                     MThreadSliceSize * ESrcVectorSize,
+                     true>
+            e_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     ComputeDataType,
+                     MThreadSliceSize * GammaSrcVectorSize,
+                     true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     ComputeDataType,
+                     MThreadSliceSize * BetaSrcVectorSize,
+                     true>
+            beta_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     ComputeDataType,
+                     MThreadSliceSize * HDstVectorSize,
+                     true>
+            h_thread_buf;
+
+        // IO
+        auto threadwise_mean_load_m_nblock =
             ThreadwiseTensorSliceTransfer_v2<MeanDataType,
                                              ComputeDataType,
-                                             MeanVarCountGridDesc_M_N,
+                                             MeanVarCountGridDesc_M_NBlock,
                                              decltype(thread_buffer_desc_m_1),
                                              ThreadBufferLengths_M_1,
-                                             Sequence<0, 1>,
+                                             ThreadBufferDimAccessOrder,
                                              1,
                                              1,
                                              1,
@@ -153,13 +189,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                  thread_m_cluster_id * MThreadSliceSize,
                                  thread_n_cluster_id));

-        auto threadwise_var_load_m_k =
+        auto threadwise_var_load_m_nblock =
             ThreadwiseTensorSliceTransfer_v2<VarDataType,
                                              ComputeDataType,
-                                             MeanVarCountGridDesc_M_N,
+                                             MeanVarCountGridDesc_M_NBlock,
                                              decltype(thread_buffer_desc_m_1),
                                              ThreadBufferLengths_M_1,
-                                             Sequence<0, 1>,
+                                             ThreadBufferDimAccessOrder,
                                              1,
                                              1,
                                              1,
@@ -169,13 +205,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                  thread_m_cluster_id * MThreadSliceSize,
                                  thread_n_cluster_id));

-        auto threadwise_count_load_m_k =
+        auto threadwise_count_load_m_nblock =
             ThreadwiseTensorSliceTransfer_v2<int32_t,
                                              int32_t,
-                                             MeanVarCountGridDesc_M_N,
+                                             MeanVarCountGridDesc_M_NBlock,
                                              decltype(thread_buffer_desc_m_1),
                                              ThreadBufferLengths_M_1,
-                                             Sequence<0, 1>,
+                                             ThreadBufferDimAccessOrder,
                                              1,
                                              1,
                                              1,
@@ -185,29 +221,68 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                  thread_m_cluster_id * MThreadSliceSize,
                                  thread_n_cluster_id));

-        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-            in_welford_mean_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-            in_welford_var_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
-            in_welford_count_thread_buf;
-
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-            welford_mean_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-            welford_var_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
-            welford_count_thread_buf;
+        auto threadwise_e_load_m_n =
+            ThreadwiseTensorSliceTransfer_v2<EDataType,
+                                             ComputeDataType,
+                                             decltype(e_grid_desc_m_n),
+                                             decltype(thread_buffer_desc_m_n),
+                                             ThreadBufferLengths_M_N,
+                                             ThreadBufferDimAccessOrder,
+                                             ESrcHDstVectorDim,
+                                             ESrcVectorSize,
+                                             1,
+                                             true>(
+                e_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id * NThreadSliceSize));
+
+        auto threadwise_gamma_load_m_n =
+            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
+                                             ComputeDataType,
+                                             decltype(gamma_grid_desc_n),
+                                             decltype(thread_buffer_desc_n),
+                                             ThreadBufferLengths_N,
+                                             Sequence<0>, // DimAccessOrder
+                                             0,           // SrcVectorDim
+                                             GammaSrcVectorSize,
+                                             1,
+                                             true>(
+                gamma_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
+
+        auto threadwise_beta_load_m_n =
+            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
+                                             ComputeDataType,
+                                             decltype(beta_grid_desc_n),
+                                             decltype(thread_buffer_desc_n),
+                                             ThreadBufferLengths_N,
+                                             Sequence<0>, // DimAccessOrder
+                                             0,           // SrcVectorDim
+                                             BetaSrcVectorSize,
+                                             1,
+                                             true>(
+                beta_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
+
+        auto threadwise_h_store_m_n =
+            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                               HDataType,
+                                               decltype(thread_buffer_desc_m_n),
+                                               decltype(h_grid_desc_m_n),
+                                               HElementwiseOperation,
+                                               ThreadBufferLengths_M_N,
+                                               ThreadBufferDimAccessOrder,
+                                               ESrcHDstVectorDim,
+                                               HDstVectorSize,
+                                               InMemoryDataOperationEnum::Set,
+                                               1,
+                                               true>(
+                h_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id * NThreadSliceSize),
+                h_element_op);

+        // step1: Merge mean and variance
         constexpr auto mean_var_count_thread_copy_step_m_n =
             make_multi_index(0, NThreadClusterSize);
@@ -220,19 +295,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
             ++reducedTiles)
         {
-            threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
-                                         welford_mean_global_val_buf,
-                                         thread_buffer_desc_m_1,
-                                         make_tuple(I0, I0),
-                                         in_welford_mean_thread_buf);
+            threadwise_mean_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
+                                              welford_mean_global_val_buf,
+                                              thread_buffer_desc_m_1,
+                                              make_tuple(I0, I0),
+                                              in_welford_mean_thread_buf);

-            threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
-                                        welford_var_global_val_buf,
-                                        thread_buffer_desc_m_1,
-                                        make_tuple(I0, I0),
-                                        in_welford_var_thread_buf);
+            threadwise_var_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
+                                             welford_var_global_val_buf,
+                                             thread_buffer_desc_m_1,
+                                             make_tuple(I0, I0),
+                                             in_welford_var_thread_buf);

-            threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
-                                          welford_count_global_val_buf,
-                                          thread_buffer_desc_m_1,
-                                          make_tuple(I0, I0),
+            threadwise_count_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
+                                               welford_count_global_val_buf,
+                                               thread_buffer_desc_m_1,
+                                               make_tuple(I0, I0),
@@ -245,11 +320,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                  welford_var_thread_buf,
                                                  welford_count_thread_buf);

-            threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-                                                        mean_var_count_thread_copy_step_m_n);
-            threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-                                                       mean_var_count_thread_copy_step_m_n);
-            threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-                                                         mean_var_count_thread_copy_step_m_n);
+            threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                             mean_var_count_thread_copy_step_m_n);
+            threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                            mean_var_count_thread_copy_step_m_n);
+            threadwise_count_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                              mean_var_count_thread_copy_step_m_n);
         }
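Step 1 above repeatedly loads per-block (mean, var, count) partials and folds them into each thread's running statistics. The pairwise merge that such a blockwise Welford performs follows the standard parallel update (Chan et al.); a hedged sketch of the math, noting that CK's actual blockwise implementation may differ in details such as whether var holds M2 or M2 / count:

    // Pairwise Welford merge of two partial statistics (illustration only).
    struct Welford
    {
        float mean  = 0.f;
        float var   = 0.f; // biased variance: M2 / count
        int   count = 0;
    };

    inline Welford welford_merge(Welford a, Welford b)
    {
        if(b.count == 0)
            return a; // also guards the all-empty case

        const int   c     = a.count + b.count;
        const float delta = b.mean - a.mean;

        Welford r;
        r.count = c;
        r.mean  = a.mean + delta * b.count / c;

        // Combine sums of squared deviations, then renormalize to a variance.
        const float m2 = a.var * a.count + b.var * b.count +
                         delta * delta * a.count * b.count / c;
        r.var = m2 / c;
        return r;
    }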
@@ -262,9 +337,64 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         });

         // step2: normalization
-        for(index_t reducedTiles = 0; reducedTiles < numEBlockTileIteration_N; ++reducedTiles)
+        for(index_t reducedTiles = 0; reducedTiles < numNormBlockTileIteration_N; ++reducedTiles)
         {
-            // TODO
+            // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
+            threadwise_e_load_m_n.Run(e_grid_desc_m_n,
+                                      e_global_val_buf,
+                                      thread_buffer_desc_m_n,
+                                      make_tuple(I0, I0),
+                                      e_thread_buf);
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+                auto divisor = 1 / __builtin_amdgcn_sqrtf(welford_var_thread_buf(m) + epsilon);
+                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                    h_thread_buf(Number<m_n>{}) =
+                        (e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
+                });
+            });
+
+            threadwise_gamma_load_m_n.Run(gamma_grid_desc_n,
+                                          gamma_global_val_buf,
+                                          thread_buffer_desc_n,
+                                          make_tuple(I0),
+                                          gamma_thread_buf);
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                    h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
+                });
+            });
+
+            threadwise_beta_load_m_n.Run(beta_grid_desc_n,
+                                         beta_global_val_buf,
+                                         thread_buffer_desc_n,
+                                         make_tuple(I0),
+                                         beta_thread_buf);
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                    h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
+                });
+            });
+
+            threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
+                                       make_tuple(I0, I0),
+                                       h_thread_buf,
+                                       h_grid_desc_m_n,
+                                       h_global_val_buf);
+
+            threadwise_e_load_m_n.MoveSrcSliceWindow(e_grid_desc_m_n,
+                                                     make_multi_index(0, N_BlockTileSize));
+            threadwise_gamma_load_m_n.MoveSrcSliceWindow(gamma_grid_desc_n,
+                                                         make_multi_index(N_BlockTileSize));
+            threadwise_beta_load_m_n.MoveSrcSliceWindow(beta_grid_desc_n,
+                                                        make_multi_index(N_BlockTileSize));
+            threadwise_h_store_m_n.MoveDstSliceWindow(h_grid_desc_m_n,
+                                                      make_multi_index(0, N_BlockTileSize));
         }
     } // run
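Step 2 fuses the whole row formula from the comment into three register-resident passes (scale by the hoisted reciprocal standard deviation, multiply by gamma, add beta) before a single vectorized store through the new HElementwiseOperation. A plain scalar reference of the same math, with std::sqrt standing in for the on-device __builtin_amdgcn_sqrtf and an illustrative function name:

    #include <cmath>

    // One row: h[n] = (e[n] - mean) / sqrt(var + eps) * gamma[n] + beta[n].
    // The reciprocal standard deviation is hoisted out of the N loop, exactly
    // as the kernel hoists `divisor` out of its inner static_for.
    inline void layernorm_row_reference(const float* e, const float* gamma,
                                        const float* beta, float* h, int N,
                                        float mean, float var, float epsilon)
    {
        const float divisor = 1.f / std::sqrt(var + epsilon); // one sqrt per row
        for(int n = 0; n < N; ++n)
            h[n] = (e[n] - mean) * divisor * gamma[n] + beta[n];
    }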
...