Commit 78ff5f81 authored by rocking

Implement layernorm
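
This completes the layernorm half of the GEMM + layernorm device op (DeviceGemmMultipleDLayernorm_Xdl_CShuffle):

- Rename MeanVarCountGridDesc_M_N to MeanVarCountGridDesc_M_NBlock: the Welford partial
  mean/variance/count produced by the first (GEMM) kernel are stored per GEMM N-block, so the
  descriptor's second dimension is NBlock = ceil(N / NPerBlock) (gemm_nblock_), not N.
- Thread an HElementwiseOperation through the kernel wrapper, the device op and
  GridwiseWelfordSecondHalfLayernorm2d; it is applied as H is written out.
- Implement the previously stubbed GridwiseWelfordSecondHalfLayernorm2d::Run: step 1 merges the
  per-block Welford statistics along NBlock, step 2 normalizes E and writes
  h[m, n] = ((e[m, n] - mean[m]) / sqrt(var[m] + eps)) * gamma[n] + beta[n].
  Only ESrcHDstVectorDim == 1 is supported for now (see the TODO above the static_assert).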

parent a4e34d88
@@ -111,8 +111,9 @@ template <typename GridwiseWelfordLayernorm,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
typename GammaBetaGridDesc_N>
typename MeanVarCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -127,13 +128,13 @@ __global__ void
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n,
const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n,
const MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n,
index_t blkgroup_size,
index_t num_mean_var_count_k_block_tile_iteration,
index_t num_xy_k_block_tile_iteration,
ComputeDataType epsilon)
index_t numMeanVarCountBlockTileIteration_N,
index_t numNormBlockTileIteration_N,
ComputeDataType epsilon,
HElementwiseOperation h_element_op)
{
GridwiseWelfordLayernorm::Run(p_e_grid,
p_in_welford_mean_grid,
@@ -144,13 +145,13 @@ __global__ void
p_h_grid,
e_grid_desc_m_n,
h_grid_desc_m_n,
mean_var_count_grid_desc_m_n,
mean_var_count_grid_desc_m_nblock,
gamma_grid_desc_n,
beta_grid_desc_n,
blkgroup_size,
num_mean_var_count_k_block_tile_iteration,
num_xy_k_block_tile_iteration,
epsilon);
numMeanVarCountBlockTileIteration_N,
numNormBlockTileIteration_N,
epsilon,
h_element_op);
}
} // namespace ck
@@ -371,12 +372,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
};
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using MeanVarCountGridDesc_M_N = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using MeanVarCountGridDesc_M_NBlock = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
@@ -394,7 +395,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K,
DsGridDesc_M_N,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -439,8 +440,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation,
BlockSize,
LayernormThreadClusterSize_M_N::At(I0),
LayernormThreadClusterSize_M_N::At(I1),
@@ -488,7 +490,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_var_count_grid_desc_m_n_{},
mean_var_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
@@ -504,7 +506,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
mean_var_count_grid_desc_m_n_ =
mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
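// gemm_nblock_ = ceil(NRaw / NPerBlock) (set above), so the Welford partial
// mean/var/count live on an [MRaw, NBlock] grid rather than an [MRaw, NRaw] one.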
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
@@ -546,7 +548,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_var_count_grid_desc_m_n_);
mean_var_count_grid_desc_m_nblock_);
}
}
@@ -578,7 +580,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_;
MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_;
MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock_;
GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_;
EHGridDesc_M_N h_grid_desc_m_n_;
@@ -666,8 +668,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
GammaBetaGridDesc_N>;
MeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation>;
avg_time +=
launch_and_time_kernel(stream_config,
@@ -692,17 +695,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
grid_size = math::integer_least_multiple(M, LayernormBlockTileSize_M_N::At(0)) /
LayernormBlockTileSize_M_N::At(0);
grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
index_t numMeanVarCountBlockTileIteration_N =
math::integer_least_multiple(arg.gemm_nblock_,
LayernormThreadClusterSize_M_N::At(I1)) /
LayernormThreadClusterSize_M_N::At(I1);
index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
index_t numEBlockTileIteration_N =
math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
LayernormBlockTileSize_M_N::At(I1);
index_t numNormBlockTileIteration_N =
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));
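// For reference: both the old integer_least_multiple(a, b) / b form and
// integer_divide_ceil(a, b) compute ceil(a / b) for positive integers,
// e.g. a = 100, b = 32 gives 4 tiles/iterations.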
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
@@ -718,13 +717,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
arg.gemm_nblock_,
numMeanVarCountBlockTileIteration_N,
numEBlockTileIteration_N,
arg.epsilon_);
numNormBlockTileIteration_N,
arg.epsilon_,
arg.h_element_op_);
return avg_time;
};
......
@@ -47,7 +47,7 @@ template <typename ABDataType,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -349,7 +349,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>;
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
......
@@ -27,8 +27,9 @@ template <typename EDataType,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation,
index_t BlockSize,
index_t MThreadClusterSize,
index_t NThreadClusterSize,
@@ -42,32 +43,34 @@ template <typename EDataType,
index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % ESrcVectorSize == 0) ||
(ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0),
// TODO - Support ESrcHDstVectorDim == 0
static_assert(ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0 &&
NThreadSliceSize % GammaSrcVectorSize == 0 &&
NThreadSliceSize % BetaSrcVectorSize == 0,
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % HDstVectorSize == 0) ||
(ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0),
static_assert(ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0,
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (ESrcHDstVectorDim == 0);
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
using ThreadBufferDimAccessOrder = Sequence<0, 1>;
using ThreadClusterArrangeOrder = Sequence<0, 1>;
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
static constexpr auto thread_cluster_desc_m_n =
make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
static constexpr auto thread_buffer_desc_m_1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
using ThreadBufferLengths_N = Sequence<NThreadSliceSize>;
static constexpr auto thread_buffer_desc_n =
make_naive_tensor_descriptor_packed(make_tuple(Number<NThreadSliceSize>{}));
using ThreadReduceSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
@@ -80,19 +83,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
ThreadClusterLengths_M_N,
ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
static constexpr index_t N_BlockTileStepSize = NThreadClusterSize * ESrcVectorSize;
static constexpr auto EThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
static constexpr auto GammaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
static constexpr auto BetaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
static constexpr auto HThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
__device__ static void Run(const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_in_welford_mean_grid,
@@ -103,47 +98,88 @@ struct GridwiseWelfordSecondHalfLayernorm2d
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
const MeanVarCountGridDesc_M_NBlock& mean_var_count_grid_desc_m_n,
const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t gemm_nblock_,
index_t numMeanVarCountBlockTileIteration_N,
index_t numEBlockTileIteration_N,
ComputeDataType epsilon)
index_t numNormBlockTileIteration_N,
ComputeDataType epsilon,
HElementwiseOperation h_element_op)
{
ignore = p_e_grid;
ignore = p_in_welford_mean_grid;
ignore = p_in_welford_var_grid;
ignore = p_in_welford_count_grid;
ignore = p_gamma_grid;
ignore = p_beta_grid;
ignore = p_h_grid;
ignore = e_grid_desc_m_n;
ignore = h_grid_desc_m_n;
ignore = mean_var_count_grid_desc_m_n;
ignore = gamma_grid_desc_n;
ignore = beta_grid_desc_n;
ignore = gemm_nblock_;
ignore = numMeanVarCountBlockTileIteration_N;
ignore = numEBlockTileIteration_N;
ignore = epsilon;
// Thread/Block id
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
// step1: Merge mean and variance
auto threadwise_mean_load_m_k =
// Global Memory
const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());
auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());
// VGPR
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * ESrcVectorSize,
true>
e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * GammaSrcVectorSize,
true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * BetaSrcVectorSize,
true>
beta_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * HDstVectorSize,
true>
h_thread_buf;
// IO
auto threadwise_mean_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<MeanDataType,
ComputeDataType,
MeanVarCountGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
ThreadBufferDimAccessOrder,
1,
1,
1,
@@ -153,13 +189,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
auto threadwise_var_load_m_k =
auto threadwise_var_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<VarDataType,
ComputeDataType,
MeanVarCountGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
ThreadBufferDimAccessOrder,
1,
1,
1,
@@ -169,13 +205,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
auto threadwise_count_load_m_k =
auto threadwise_count_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
ThreadBufferDimAccessOrder,
1,
1,
1,
@@ -185,29 +221,68 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
auto threadwise_e_load_m_n =
ThreadwiseTensorSliceTransfer_v2<EDataType,
ComputeDataType,
decltype(e_grid_desc_m_n),
decltype(thread_buffer_desc_m_n),
ThreadBufferLengths_M_N,
ThreadBufferDimAccessOrder,
ESrcHDstVectorDim,
ESrcVectorSize,
1,
true>(
e_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id * NThreadSliceSize));
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
in_welford_count_thread_buf;
auto threadwise_gamma_load_m_n =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
ComputeDataType,
decltype(gamma_grid_desc_n),
decltype(thread_buffer_desc_n),
ThreadBufferLengths_N,
Sequence<0>, // DimAccessOrder,
0, // SrcVectorDim,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
auto threadwise_beta_load_m_n =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
ComputeDataType,
decltype(beta_grid_desc_n),
decltype(thread_buffer_desc_n),
ThreadBufferLengths_N,
Sequence<0>, // DimAccessOrder,
0, // SrcVectorDim,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
auto threadwise_h_store_m_n =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
HDataType,
decltype(thread_buffer_desc_m_n),
decltype(h_grid_desc_m_n),
HElementwiseOperation,
ThreadBufferLengths_M_N,
ThreadBufferDimAccessOrder,
ESrcHDstVectorDim,
HDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
h_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id * NThreadSliceSize),
h_element_op);
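// Note: ThreadwiseTensorSliceTransfer_v1r3 applies h_element_op to each element as it
// is copied from the VGPR buffer to p_h_grid, so no separate elementwise pass is needed.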
// step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_m_n =
make_multi_index(0, NThreadClusterSize);
@@ -220,23 +295,23 @@ struct GridwiseWelfordSecondHalfLayernorm2d
for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
++reducedTiles)
{
threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
welford_mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
welford_var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
threadwise_mean_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
welford_mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_var_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
welford_var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_welford_mean_thread_buf,
in_welford_var_thread_buf,
@@ -245,12 +320,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -262,9 +337,64 @@ struct GridwiseWelfordSecondHalfLayernorm2d
});
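// Reference for the merge in step1 (parallel Welford, Chan et al.): combining two
// partials (mean_a, var_a, n_a) and (mean_b, var_b, n_b) gives
//   n     = n_a + n_b
//   delta = mean_b - mean_a
//   mean  = mean_a + delta * n_b / n
//   var   = (var_a * n_a + var_b * n_b + delta * delta * n_a * n_b / n) / n
// The ThreadwiseWelford step above is expected to implement an equivalent update.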
// step2: normalization
for(index_t reducedTiles = 0; reducedTiles < numEBlockTileIteration_N; ++reducedTiles)
for(index_t reducedTiles = 0; reducedTiles < numNormBlockTileIteration_N; ++reducedTiles)
{
// TODO
// h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
threadwise_e_load_m_n.Run(e_grid_desc_m_n,
e_global_val_buf,
thread_buffer_desc_m_n,
make_tuple(I0, I0),
e_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
auto divisor = 1 / __builtin_amdgcn_sqrtf(welford_var_thread_buf(m) + epsilon);
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) =
(e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
});
});
threadwise_gamma_load_m_n.Run(gamma_grid_desc_n,
gamma_global_val_buf,
thread_buffer_desc_n,
make_tuple(I0),
gamma_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
});
});
threadwise_beta_load_m_n.Run(beta_grid_desc_n,
beta_global_val_buf,
thread_buffer_desc_n,
make_tuple(I0),
beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
});
});
threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
make_tuple(I0, I0),
h_thread_buf,
h_grid_desc_m_n,
h_global_val_buf);
threadwise_e_load_m_n.MoveSrcSliceWindow(e_grid_desc_m_n,
make_multi_index(0, N_BlockTileSize));
threadwise_gamma_load_m_n.MoveSrcSliceWindow(gamma_grid_desc_n,
make_multi_index(N_BlockTileSize));
threadwise_beta_load_m_n.MoveSrcSliceWindow(beta_grid_desc_n,
make_multi_index(N_BlockTileSize));
threadwise_h_store_m_n.MoveDstSliceWindow(h_grid_desc_m_n,
make_multi_index(0, N_BlockTileSize));
}
} // run
......
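
For readers checking the math, below is a minimal host-side sketch of what the new second-half
kernel computes for one output row: merge the per-N-block Welford partials, then normalize. It
only illustrates the arithmetic; the function name, parameter names and types are made up for the
example and are not part of the device op's API.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference-only sketch: merge per-N-block Welford partials for one row, then
// apply layernorm h = ((e - mean) / sqrt(var + eps)) * gamma + beta.
void reference_second_half_row(const std::vector<float>& e_row,         // [N]
                               const std::vector<float>& block_mean,    // [NBlock]
                               const std::vector<float>& block_var,     // [NBlock]
                               const std::vector<int32_t>& block_count, // [NBlock]
                               const std::vector<float>& gamma,         // [N]
                               const std::vector<float>& beta,          // [N]
                               float epsilon,
                               std::vector<float>& h_row)               // [N]
{
    // Step 1: merge the partial (mean, var, count) statistics along NBlock
    // (parallel Welford / Chan et al., biased variance as used by layernorm).
    float mean    = 0.f;
    float var     = 0.f;
    int32_t count = 0;
    for(std::size_t b = 0; b < block_mean.size(); ++b)
    {
        const int32_t n_new = count + block_count[b];
        const float delta   = block_mean[b] - mean;
        mean += delta * block_count[b] / n_new;
        var = (var * count + block_var[b] * block_count[b] +
               delta * delta * count * block_count[b] / n_new) /
              n_new;
        count = n_new;
    }

    // Step 2: normalize each element, then scale by gamma and shift by beta.
    const float inv_std = 1.f / std::sqrt(var + epsilon);
    for(std::size_t n = 0; n < e_row.size(); ++n)
        h_row[n] = (e_row[n] - mean) * inv_std * gamma[n] + beta[n];
}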