Commit 7fd01e87 authored by wangshaojie6's avatar wangshaojie6
Browse files

1. coalesce load/store data for gridwise layer norm welford. 2. move a sqrt...

1. coalesce load/store data for gridwise layer norm welford. 2. move a sqrt and divison into a outer static loop
parent d7bb21c2
......@@ -93,10 +93,13 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
if(kPerBlockTail > 0)
{
int thread_max_len = (thread_k_cluster_id + 1) * KThreadSliceSize;
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
int thread_max_len =
(thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
int delta = thread_max_len - kPerBlockTail;
delta = math::clamp(thread_max_len - kPerBlockTail, 0, KThreadSliceSize);
kPerThread += KThreadSliceSize - delta;
delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
kPerThread += XSrcVectorSize - delta;
});
}
return kPerThread;
......@@ -162,22 +165,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
},
Number<YThreadBufferNumber>{});
#if 0
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>& beta_thread_buf = gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
#endif
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
......@@ -190,11 +177,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
// using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
// constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
// make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
// auto red_num = slice/vector;
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
......@@ -212,7 +194,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
thread_k_cluster_id * XSrcVectorSize));
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
......@@ -228,7 +210,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
gamma_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
thread_k_cluster_id * GammaSrcVectorSize));
auto threadwise_beta_load =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
......@@ -244,7 +226,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
beta_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
thread_k_cluster_id * BetaSrcVectorSize));
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
......@@ -262,16 +244,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
y_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
thread_k_cluster_id * YDstVectorSize),
acc_elementwise_op);
// Copy x from Cache
// one pass: fwd, second pass: bwd
// constexpr auto thread_copy_fwd_step_m_k =
// make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileStepSize);
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
......@@ -305,7 +281,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
});
}
#if 1
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
......@@ -313,9 +288,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});
#endif
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
......@@ -336,7 +311,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
});
}
#if 1
static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
......@@ -344,30 +318,21 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
make_tuple(I0, I0),
gamma_thread_buf(i));
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
#endif
#if 0
static_for<0, gamma_thread_buf.Size(), 1>{}([&](auto i){
gamma_thread_buf(i) = 1;
beta_thread_buf(i) = 1;
});
#endif
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / sqrt(var_thread_buf(iM) + epsilon);
auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
#if 1
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
#endif
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
......@@ -376,16 +341,16 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
});
});
});
#if 1
static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
#endif
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
......@@ -411,8 +376,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
});
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment