Commit 18e65656 authored by rocking
Browse files

Refine naming

parent 44b66c41
...@@ -96,8 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -96,8 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
HDataType* __restrict__ p_h_grid, HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n, const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n, const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_n, const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock,
const CountGridDesc_M_NBlock& count_grid_desc_m_n, const CountGridDesc_M_NBlock& count_grid_desc_m_nblock,
const GammaBetaGridDesc_N& gamma_grid_desc_n, const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n, const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N, index_t numMeanVarCountBlockTileIteration_N,
...@@ -121,13 +121,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -121,13 +121,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
p_e_grid, e_grid_desc_m_n.GetElementSpaceSize()); p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean_grid, mean_var_grid_desc_m_n.GetElementSpaceSize()); p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_var_grid, mean_var_grid_desc_m_n.GetElementSpaceSize()); p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count_grid, count_grid_desc_m_n.GetElementSpaceSize()); p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize()); p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
...@@ -186,7 +186,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -186,7 +186,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
1, 1,
true>( true>(
mean_var_grid_desc_m_n, mean_var_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize + make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -202,7 +202,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -202,7 +202,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
1, 1,
true>( true>(
mean_var_grid_desc_m_n, mean_var_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize + make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -218,7 +218,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -218,7 +218,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
1, 1,
true>( true>(
count_grid_desc_m_n, count_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize + make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -289,8 +289,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -289,8 +289,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
h_element_op); h_element_op);
// step1: Merge mean and variance // step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_0_n = constexpr auto mean_var_count_thread_copy_step_I0_n =
make_multi_index(0, NThreadClusterSize); make_multi_index(I0, NThreadClusterSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f); welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
...@@ -300,19 +300,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -300,19 +300,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n) for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n)
{ {
threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_n, threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
welford_mean_global_val_buf, welford_mean_global_val_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
in_welford_mean_thread_buf); in_welford_mean_thread_buf);
threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_n, threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
welford_var_global_val_buf, welford_var_global_val_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
in_welford_var_thread_buf); in_welford_var_thread_buf);
threadwise_count_load_m_nblock.Run(count_grid_desc_m_n, threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock,
welford_count_global_val_buf, welford_count_global_val_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
...@@ -325,12 +325,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -325,12 +325,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
welford_var_thread_buf, welford_var_thread_buf,
welford_count_thread_buf); welford_count_thread_buf);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n, threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
mean_var_count_thread_copy_step_0_n); mean_var_count_thread_copy_step_I0_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n, threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
mean_var_count_thread_copy_step_0_n); mean_var_count_thread_copy_step_I0_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_n, threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock,
mean_var_count_thread_copy_step_0_n); mean_var_count_thread_copy_step_I0_n);
} }
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment