Commit 18e65656 authored by rocking's avatar rocking
Browse files

Refine naming

parent 44b66c41
......@@ -96,8 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_n,
const CountGridDesc_M_NBlock& count_grid_desc_m_n,
const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock,
const CountGridDesc_M_NBlock& count_grid_desc_m_nblock,
const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N,
......@@ -121,13 +121,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean_grid, mean_var_grid_desc_m_n.GetElementSpaceSize());
p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_var_grid, mean_var_grid_desc_m_n.GetElementSpaceSize());
p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count_grid, count_grid_desc_m_n.GetElementSpaceSize());
p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
......@@ -186,7 +186,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1,
1,
true>(
mean_var_grid_desc_m_n,
mean_var_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
......@@ -202,7 +202,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1,
1,
true>(
mean_var_grid_desc_m_n,
mean_var_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
......@@ -218,7 +218,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1,
1,
true>(
count_grid_desc_m_n,
count_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
......@@ -289,8 +289,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
h_element_op);
// step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_0_n =
make_multi_index(0, NThreadClusterSize);
constexpr auto mean_var_count_thread_copy_step_I0_n =
make_multi_index(I0, NThreadClusterSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
......@@ -300,19 +300,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n)
{
threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_n,
threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
welford_mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_n,
threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
welford_var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_nblock.Run(count_grid_desc_m_n,
threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
......@@ -325,12 +325,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
mean_var_count_thread_copy_step_0_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
mean_var_count_thread_copy_step_0_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_n,
mean_var_count_thread_copy_step_0_n);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment