Commit c83bad81 authored by rocking's avatar rocking
Browse files

Rewrite the 2nd kernel, use multiple blocks along the N dimension in the layernorm kernel

parent 90546fbe
...@@ -142,7 +142,7 @@ __global__ void ...@@ -142,7 +142,7 @@ __global__ void
const GammaBetaGridDesc_N gamma_grid_desc_n, const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n, const GammaBetaGridDesc_N beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N, index_t numMeanVarCountBlockTileIteration_N,
index_t numNormBlockTileIteration_N, index_t NBlockClusterLength,
ComputeDataType epsilon, ComputeDataType epsilon,
HElementwiseOperation h_element_op) HElementwiseOperation h_element_op)
{ {
...@@ -160,7 +160,7 @@ __global__ void ...@@ -160,7 +160,7 @@ __global__ void
gamma_grid_desc_n, gamma_grid_desc_n,
beta_grid_desc_n, beta_grid_desc_n,
numMeanVarCountBlockTileIteration_N, numMeanVarCountBlockTileIteration_N,
numNormBlockTileIteration_N, NBlockClusterLength,
epsilon, epsilon,
h_element_op); h_element_op);
} }
...@@ -557,7 +557,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -557,7 +557,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]); DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
}); });
// populate desc for Ds/E/F/G // populate desc for Ds/E/mean/var/count
if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_, if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_, b_grid_desc_n_k_,
ds_grid_desc_m_n_, ds_grid_desc_m_n_,
...@@ -736,14 +736,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -736,14 +736,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
arg.block_2_etile_map_, arg.block_2_etile_map_,
arg.NRaw_); arg.NRaw_);
grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0)); index_t MBlockClusterLength =
math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
index_t NBlockClusterLength =
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1));
grid_size = MBlockClusterLength * NBlockClusterLength;
index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil( index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1)); arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
index_t numNormBlockTileIteration_N =
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));
avg_time += avg_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
kernel_welford_layernorm, kernel_welford_layernorm,
...@@ -764,7 +765,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -764,7 +765,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
arg.gamma_grid_desc_n_, arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_, arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N, numMeanVarCountBlockTileIteration_N,
numNormBlockTileIteration_N, NBlockClusterLength,
arg.epsilon_, arg.epsilon_,
arg.h_element_op_); arg.h_element_op_);
......
...@@ -101,13 +101,16 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -101,13 +101,16 @@ struct GridwiseWelfordSecondHalfLayernorm2d
const GammaBetaGridDesc_N& gamma_grid_desc_n, const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n, const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N, index_t numMeanVarCountBlockTileIteration_N,
index_t numNormBlockTileIteration_N, index_t NBlockClusterLength,
ComputeDataType epsilon, ComputeDataType epsilon,
HElementwiseOperation h_element_op) HElementwiseOperation h_element_op)
{ {
// Thread/Block id // Thread/Block id
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
const auto block_work_idx = make_tuple(block_global_id / NBlockClusterLength,
block_global_id % NBlockClusterLength);
const auto thread_cluster_idx = const auto thread_cluster_idx =
thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id)); thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
...@@ -152,22 +155,22 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -152,22 +155,22 @@ struct GridwiseWelfordSecondHalfLayernorm2d
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType, ComputeDataType,
MThreadSliceSize * ESrcVectorSize, MThreadSliceSize * NThreadSliceSize,
true> true>
e_thread_buf; e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType, ComputeDataType,
MThreadSliceSize * GammaSrcVectorSize, MThreadSliceSize * NThreadSliceSize,
true> true>
gamma_thread_buf; gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType, ComputeDataType,
MThreadSliceSize * BetaSrcVectorSize, MThreadSliceSize * NThreadSliceSize,
true> true>
beta_thread_buf; beta_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType, ComputeDataType,
MThreadSliceSize * HDstVectorSize, MThreadSliceSize * NThreadSliceSize,
true> true>
h_thread_buf; h_thread_buf;
...@@ -184,7 +187,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -184,7 +187,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
true>( true>(
mean_var_grid_desc_m_n, mean_var_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -200,7 +203,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -200,7 +203,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
true>( true>(
mean_var_grid_desc_m_n, mean_var_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -216,7 +219,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -216,7 +219,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
true>( true>(
count_grid_desc_m_n, count_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -232,9 +235,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -232,9 +235,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
true>( true>(
e_grid_desc_m_n, e_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(
thread_m_cluster_id * MThreadSliceSize, block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id * NThreadSliceSize)); block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize));
auto threadwise_gamma_load_m_n = auto threadwise_gamma_load_m_n =
ThreadwiseTensorSliceTransfer_v2<GammaDataType, ThreadwiseTensorSliceTransfer_v2<GammaDataType,
...@@ -247,7 +250,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -247,7 +250,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
GammaSrcVectorSize, GammaSrcVectorSize,
1, 1,
true>( true>(
gamma_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize)); gamma_grid_desc_n,
make_multi_index(block_work_idx[I1] * N_BlockTileSize +
thread_n_cluster_id * NThreadSliceSize));
auto threadwise_beta_load_m_n = auto threadwise_beta_load_m_n =
ThreadwiseTensorSliceTransfer_v2<BetaDataType, ThreadwiseTensorSliceTransfer_v2<BetaDataType,
...@@ -260,7 +265,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -260,7 +265,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
BetaSrcVectorSize, BetaSrcVectorSize,
1, 1,
true>( true>(
beta_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize)); beta_grid_desc_n,
make_multi_index(block_work_idx[I1] * N_BlockTileSize +
thread_n_cluster_id * NThreadSliceSize));
auto threadwise_h_store_m_n = auto threadwise_h_store_m_n =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType, ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
...@@ -276,13 +283,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -276,13 +283,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
true>( true>(
h_grid_desc_m_n, h_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(
thread_m_cluster_id * MThreadSliceSize, block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id * NThreadSliceSize), block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize),
h_element_op); h_element_op);
// step1: Merge mean and variance // step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_m_n = constexpr auto mean_var_count_thread_copy_step_0_n =
make_multi_index(0, NThreadClusterSize); make_multi_index(0, NThreadClusterSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -320,11 +327,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -320,11 +327,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
welford_count_thread_buf); welford_count_thread_buf);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n, threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n); mean_var_count_thread_copy_step_0_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n, threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n); mean_var_count_thread_copy_step_0_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_n, threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n); mean_var_count_thread_copy_step_0_n);
} }
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -336,8 +343,6 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -336,8 +343,6 @@ struct GridwiseWelfordSecondHalfLayernorm2d
}); });
// step2: normalization // step2: normalization
for(index_t reducedTiles = 0; reducedTiles < numNormBlockTileIteration_N; ++reducedTiles)
{
// h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n] // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
threadwise_e_load_m_n.Run(e_grid_desc_m_n, threadwise_e_load_m_n.Run(e_grid_desc_m_n,
e_global_val_buf, e_global_val_buf,
...@@ -386,16 +391,6 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -386,16 +391,6 @@ struct GridwiseWelfordSecondHalfLayernorm2d
h_grid_desc_m_n, h_grid_desc_m_n,
h_global_val_buf); h_global_val_buf);
threadwise_e_load_m_n.MoveSrcSliceWindow(e_grid_desc_m_n,
make_multi_index(0, N_BlockTileSize));
threadwise_gamma_load_m_n.MoveSrcSliceWindow(gamma_grid_desc_n,
make_multi_index(N_BlockTileSize));
threadwise_beta_load_m_n.MoveSrcSliceWindow(beta_grid_desc_n,
make_multi_index(N_BlockTileSize));
threadwise_h_store_m_n.MoveDstSliceWindow(h_grid_desc_m_n,
make_multi_index(0, N_BlockTileSize));
}
} // run } // run
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment