Commit c83bad81 authored by rocking

Rewrite the 2nd kernel: use multiple blocks along the N dimension in the layernorm kernel

parent 90546fbe
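Previously the layernorm kernel launched one block per M block-tile and each block looped over every N block-tile. This commit launches a 2D grid of M-tiles x N-tiles instead, so each block normalizes exactly one tile. Below is a minimal host-side sketch of the new decomposition; `MBlockClusterLength`, `NBlockClusterLength`, and `block_work_idx` mirror names in the diff, while `TileM`, `TileN`, `bid`, and the numeric sizes are illustrative assumptions, not the library's identifiers:

```cpp
#include <cstdio>

// Sketch only (illustrative names, not the library's API): how the commit's
// 2D decomposition maps a flat block id onto an (m, n) layernorm tile.
int main()
{
    const int M = 1024, N = 768;      // example problem size (assumed)
    const int TileM = 32, TileN = 64; // stands in for LayernormBlockTileSize_M_N::At(0)/At(1)

    const int m_blocks  = (M + TileM - 1) / TileM; // MBlockClusterLength = 32
    const int n_blocks  = (N + TileN - 1) / TileN; // NBlockClusterLength = 12
    const int grid_size = m_blocks * n_blocks;     // 384 blocks; was m_blocks = 32

    // What each block computes from its flat id (block_work_idx in the diff):
    const int bid     = 30;
    const int block_m = bid / n_blocks; // 2
    const int block_n = bid % n_blocks; // 6
    printf("grid=%d, block %d -> tile (%d, %d), rows [%d,%d), cols [%d,%d)\n",
           grid_size, bid, block_m, block_n,
           block_m * TileM, (block_m + 1) * TileM,
           block_n * TileN, (block_n + 1) * TileN);
}
```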
@@ -142,7 +142,7 @@ __global__ void
            const GammaBetaGridDesc_N gamma_grid_desc_n,
            const GammaBetaGridDesc_N beta_grid_desc_n,
            index_t numMeanVarCountBlockTileIteration_N,
-           index_t numNormBlockTileIteration_N,
+           index_t NBlockClusterLength,
            ComputeDataType epsilon,
            HElementwiseOperation h_element_op)
 {
@@ -160,7 +160,7 @@ __global__ void
         gamma_grid_desc_n,
         beta_grid_desc_n,
         numMeanVarCountBlockTileIteration_N,
-        numNormBlockTileIteration_N,
+        NBlockClusterLength,
         epsilon,
         h_element_op);
 }
@@ -557,7 +557,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
             DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
         });
-        // populate desc for Ds/E/F/G
+        // populate desc for Ds/E/mean/var/count
         if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
                                               b_grid_desc_n_k_,
                                               ds_grid_desc_m_n_,
@@ -736,14 +736,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
                 arg.block_2_etile_map_,
                 arg.NRaw_);
 
-            grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
+            index_t MBlockClusterLength =
+                math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
+            index_t NBlockClusterLength =
+                math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1));
+            grid_size = MBlockClusterLength * NBlockClusterLength;
 
             index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
                 arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
-            index_t numNormBlockTileIteration_N =
-                math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));
 
             avg_time +=
                 launch_and_time_kernel(stream_config,
                                        kernel_welford_layernorm,
@@ -764,7 +765,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
                                        arg.gamma_grid_desc_n_,
                                        arg.beta_grid_desc_n_,
                                        numMeanVarCountBlockTileIteration_N,
-                                       numNormBlockTileIteration_N,
+                                       NBlockClusterLength,
                                        arg.epsilon_,
                                        arg.h_element_op_);
@@ -101,13 +101,16 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                const GammaBetaGridDesc_N& gamma_grid_desc_n,
                                const GammaBetaGridDesc_N& beta_grid_desc_n,
                                index_t numMeanVarCountBlockTileIteration_N,
-                               index_t numNormBlockTileIteration_N,
+                               index_t NBlockClusterLength,
                                ComputeDataType epsilon,
                                HElementwiseOperation h_element_op)
     {
         // Thread/Block id
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
+        const auto block_work_idx     = make_tuple(block_global_id / NBlockClusterLength,
+                                                   block_global_id % NBlockClusterLength);
         const auto thread_cluster_idx =
             thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
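Note that this decomposition is row-major over tiles: consecutive block ids differ in the N coordinate, so the blocks that share a set of rows, and therefore read the same per-row mean and variance, are adjacent in the grid.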
@@ -152,22 +155,22 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         StaticBuffer<AddressSpaceEnum::Vgpr,
                      ComputeDataType,
-                     MThreadSliceSize * ESrcVectorSize,
+                     MThreadSliceSize * NThreadSliceSize,
                      true>
             e_thread_buf;
 
         StaticBuffer<AddressSpaceEnum::Vgpr,
                      ComputeDataType,
-                     MThreadSliceSize * GammaSrcVectorSize,
+                     MThreadSliceSize * NThreadSliceSize,
                      true>
             gamma_thread_buf;
 
         StaticBuffer<AddressSpaceEnum::Vgpr,
                      ComputeDataType,
-                     MThreadSliceSize * BetaSrcVectorSize,
+                     MThreadSliceSize * NThreadSliceSize,
                      true>
             beta_thread_buf;
 
         StaticBuffer<AddressSpaceEnum::Vgpr,
                      ComputeDataType,
-                     MThreadSliceSize * HDstVectorSize,
+                     MThreadSliceSize * NThreadSliceSize,
                      true>
             h_thread_buf;
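The per-thread staging buffers are resized from a vector-width footprint (MThreadSliceSize * ESrcVectorSize and so on) to the full thread slice (MThreadSliceSize * NThreadSliceSize), matching the one-shot tile load and store in the rewritten step 2 below.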
@@ -184,7 +187,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 1,
                 true>(
                 mean_var_grid_desc_m_n,
-                make_multi_index(block_global_id * M_BlockTileSize +
+                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_n_cluster_id));
@@ -200,7 +203,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 1,
                 true>(
                 mean_var_grid_desc_m_n,
-                make_multi_index(block_global_id * M_BlockTileSize +
+                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_n_cluster_id));
@@ -216,7 +219,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 1,
                 true>(
                 count_grid_desc_m_n,
-                make_multi_index(block_global_id * M_BlockTileSize +
+                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_n_cluster_id));
@@ -232,9 +235,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 1,
                 true>(
                 e_grid_desc_m_n,
-                make_multi_index(block_global_id * M_BlockTileSize +
-                                     thread_m_cluster_id * MThreadSliceSize,
-                                 thread_n_cluster_id * NThreadSliceSize));
+                make_multi_index(
+                    block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
+                    block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize));
 
         auto threadwise_gamma_load_m_n =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
@@ -247,7 +250,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 GammaSrcVectorSize,
                 1,
                 true>(
-                gamma_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
+                gamma_grid_desc_n,
+                make_multi_index(block_work_idx[I1] * N_BlockTileSize +
+                                 thread_n_cluster_id * NThreadSliceSize));
 
         auto threadwise_beta_load_m_n =
             ThreadwiseTensorSliceTransfer_v2<BetaDataType,
@@ -260,7 +265,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 BetaSrcVectorSize,
                 1,
                 true>(
-                beta_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
+                beta_grid_desc_n,
+                make_multi_index(block_work_idx[I1] * N_BlockTileSize +
+                                 thread_n_cluster_id * NThreadSliceSize));
 
         auto threadwise_h_store_m_n =
             ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
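Every transfer origin picks up the block's N offset in the same way: the starting column becomes block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize (for the e, gamma, and beta loads and the h store below), and the row origin switches from block_global_id to block_work_idx[I0], since the flat block id no longer equals the M tile index once the grid is 2D.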
@@ -276,13 +283,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 1,
                 true>(
                 h_grid_desc_m_n,
-                make_multi_index(block_global_id * M_BlockTileSize +
-                                     thread_m_cluster_id * MThreadSliceSize,
-                                 thread_n_cluster_id * NThreadSliceSize),
+                make_multi_index(
+                    block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
+                    block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize),
                 h_element_op);
 
         // step1: Merge mean and variance
-        constexpr auto mean_var_count_thread_copy_step_m_n =
+        constexpr auto mean_var_count_thread_copy_step_0_n =
             make_multi_index(0, NThreadClusterSize);
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -320,11 +327,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                  welford_count_thread_buf);
 
             threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
-                                                             mean_var_count_thread_copy_step_m_n);
+                                                             mean_var_count_thread_copy_step_0_n);
             threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
-                                                            mean_var_count_thread_copy_step_m_n);
+                                                            mean_var_count_thread_copy_step_0_n);
             threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_n,
-                                                              mean_var_count_thread_copy_step_m_n);
+                                                              mean_var_count_thread_copy_step_0_n);
         }
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
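Step 1 above merges, per row, the partial (mean, variance, count) triples written by the first GEMM+Welford kernel across the gemm_nblock_ partials. A minimal host-side sketch of the underlying combine rule (Chan et al.'s parallel variance update); the struct and function names here are illustrative, not the library's Welford helpers:

```cpp
#include <cstdio>

// Illustrative only: the combine rule that step 1 applies across the
// numMeanVarCountBlockTileIteration_N partial results for each row.
struct Welford
{
    float mean; // running mean
    float var;  // biased variance (M2 / count)
    int count;  // number of elements folded in
};

// Merge two partials; exact regardless of how elements were partitioned.
Welford merge(const Welford& a, const Welford& b)
{
    const int n = a.count + b.count;
    if(n == 0)
        return {0.f, 0.f, 0};
    const float delta = b.mean - a.mean;
    const float mean  = a.mean + delta * b.count / n;
    const float m2    = a.var * a.count + b.var * b.count +
                        delta * delta * (float(a.count) * b.count / n);
    return {mean, m2 / n, n};
}

int main()
{
    // Partials over {1,2,3} and {4,5}: merged mean is 3, biased variance 2.
    Welford a{2.f, 2.f / 3.f, 3}, b{4.5f, 0.25f, 2};
    Welford c = merge(a, b);
    printf("mean=%g var=%g count=%d\n", c.mean, c.var, c.count);
}
```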
@@ -336,66 +343,54 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         });
 
         // step2: normalization
-        for(index_t reducedTiles = 0; reducedTiles < numNormBlockTileIteration_N; ++reducedTiles)
-        {
-            // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
-            threadwise_e_load_m_n.Run(e_grid_desc_m_n,
-                                      e_global_val_buf,
-                                      thread_buffer_desc_m_n,
-                                      make_tuple(I0, I0),
-                                      e_thread_buf);
-
-            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
-                auto divisor = 1 / __builtin_amdgcn_sqrtf(welford_var_thread_buf(m) + epsilon);
-                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
-                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
-                    h_thread_buf(Number<m_n>{}) =
-                        (e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
-                });
-            });
+        // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
+        threadwise_e_load_m_n.Run(e_grid_desc_m_n,
+                                  e_global_val_buf,
+                                  thread_buffer_desc_m_n,
+                                  make_tuple(I0, I0),
+                                  e_thread_buf);
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+            auto divisor = 1 / __builtin_amdgcn_sqrtf(welford_var_thread_buf(m) + epsilon);
+            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                h_thread_buf(Number<m_n>{}) =
+                    (e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
+            });
+        });
 
-            threadwise_gamma_load_m_n.Run(gamma_grid_desc_n,
-                                          gamma_global_val_buf,
-                                          thread_buffer_desc_n,
-                                          make_tuple(I0),
-                                          gamma_thread_buf);
+        threadwise_gamma_load_m_n.Run(gamma_grid_desc_n,
+                                      gamma_global_val_buf,
+                                      thread_buffer_desc_n,
+                                      make_tuple(I0),
+                                      gamma_thread_buf);
 
-            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
-                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
-                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
-                    h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
-                });
-            });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
+            });
+        });
 
-            threadwise_beta_load_m_n.Run(beta_grid_desc_n,
-                                         beta_global_val_buf,
-                                         thread_buffer_desc_n,
-                                         make_tuple(I0),
-                                         beta_thread_buf);
+        threadwise_beta_load_m_n.Run(beta_grid_desc_n,
+                                     beta_global_val_buf,
+                                     thread_buffer_desc_n,
+                                     make_tuple(I0),
+                                     beta_thread_buf);
 
-            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
-                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
-                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
-                    h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
-                });
-            });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
+            });
+        });
 
-            threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
-                                       make_tuple(I0, I0),
-                                       h_thread_buf,
-                                       h_grid_desc_m_n,
-                                       h_global_val_buf);
-
-            threadwise_e_load_m_n.MoveSrcSliceWindow(e_grid_desc_m_n,
-                                                     make_multi_index(0, N_BlockTileSize));
-            threadwise_gamma_load_m_n.MoveSrcSliceWindow(gamma_grid_desc_n,
-                                                         make_multi_index(N_BlockTileSize));
-            threadwise_beta_load_m_n.MoveSrcSliceWindow(beta_grid_desc_n,
-                                                        make_multi_index(N_BlockTileSize));
-            threadwise_h_store_m_n.MoveDstSliceWindow(h_grid_desc_m_n,
-                                                      make_multi_index(0, N_BlockTileSize));
-        }
+        threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
+                                   make_tuple(I0, I0),
+                                   h_thread_buf,
+                                   h_grid_desc_m_n,
+                                   h_global_val_buf);
     } // run
 };
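For reference, the normalization that each tile applies is exactly the formula in the step-2 comment. A plain host-side sketch (not the device code; the row-major e/h layout and per-row mean/var vectors are assumptions for illustration):

```cpp
#include <cmath>
#include <vector>

// h[m, n] = ((e[m, n] - mean[m]) / sqrt(var[m] + eps)) * gamma[n] + beta[n]
void layernorm_ref(const std::vector<float>& e, const std::vector<float>& mean,
                   const std::vector<float>& var, const std::vector<float>& gamma,
                   const std::vector<float>& beta, std::vector<float>& h,
                   int M, int N, float eps)
{
    for(int m = 0; m < M; ++m)
    {
        // Matches the kernel's divisor = 1 / __builtin_amdgcn_sqrtf(var + eps).
        const float inv_std = 1.f / std::sqrt(var[m] + eps);
        for(int n = 0; n < N; ++n)
            h[m * N + n] = (e[m * N + n] - mean[m]) * inv_std * gamma[n] + beta[n];
    }
}
```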