Commit 7fd01e87 authored by wangshaojie6's avatar wangshaojie6
Browse files

1. coalesce load/store data for gridwise layer norm welford. 2. move a sqrt...

1. Coalesce load/store data for gridwise layer norm welford. 2. Move a sqrt and division into an outer static loop
parent d7bb21c2
...@@ -93,10 +93,13 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -93,10 +93,13 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
if(kPerBlockTail > 0) if(kPerBlockTail > 0)
{ {
int thread_max_len = (thread_k_cluster_id + 1) * KThreadSliceSize; static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
int delta = thread_max_len - kPerBlockTail; int thread_max_len =
delta = math::clamp(thread_max_len - kPerBlockTail, 0, KThreadSliceSize); (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
kPerThread += KThreadSliceSize - delta; int delta = thread_max_len - kPerBlockTail;
delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
kPerThread += XSrcVectorSize - delta;
});
} }
return kPerThread; return kPerThread;
...@@ -162,22 +165,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -162,22 +165,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
}, },
Number<YThreadBufferNumber>{}); Number<YThreadBufferNumber>{});
#if 0
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>& beta_thread_buf = gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
#endif
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
...@@ -190,11 +177,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -190,11 +177,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1]; const auto thread_k_cluster_id = thread_cluster_idx[I1];
// using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
// constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
// make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
// auto red_num = slice/vector;
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>; using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
...@@ -212,7 +194,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -212,7 +194,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
x_grid_desc_m_k, x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * XSrcVectorSize));
auto threadwise_gamma_load = auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType, ThreadwiseTensorSliceTransfer_v2<GammaDataType,
...@@ -228,7 +210,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -228,7 +210,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * GammaSrcVectorSize));
auto threadwise_beta_load = auto threadwise_beta_load =
ThreadwiseTensorSliceTransfer_v2<BetaDataType, ThreadwiseTensorSliceTransfer_v2<BetaDataType,
...@@ -244,7 +226,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -244,7 +226,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
beta_grid_desc_m_k, beta_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * BetaSrcVectorSize));
auto threadwise_y_store = auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
...@@ -262,16 +244,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -262,16 +244,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
y_grid_desc_m_k, y_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize), thread_k_cluster_id * YDstVectorSize),
acc_elementwise_op); acc_elementwise_op);
// Copy x from Cache constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
// one pass: fwd, second pass: bwd
// constexpr auto thread_copy_fwd_step_m_k =
// make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k = constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...@@ -305,7 +281,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -305,7 +281,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
}); });
} }
#if 1
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0) if constexpr(I > 0)
block_sync_lds(); block_sync_lds();
...@@ -313,9 +288,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -313,9 +288,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_; int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
}); });
#endif
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
...@@ -336,7 +311,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -336,7 +311,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
}); });
} }
#if 1
static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k, threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf, gamma_global_val_buf,
...@@ -344,30 +318,21 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -344,30 +318,21 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
make_tuple(I0, I0), make_tuple(I0, I0),
gamma_thread_buf(i)); gamma_thread_buf(i));
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_fwd_step_m_k); threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
}); });
#endif
#if 0
static_for<0, gamma_thread_buf.Size(), 1>{}([&](auto i){
gamma_thread_buf(i) = 1;
beta_thread_buf(i) = 1;
});
#endif
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / sqrt(var_thread_buf(iM) + epsilon); auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
#if 1
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; divisor;
#endif
// gamma // gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
...@@ -376,16 +341,16 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -376,16 +341,16 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
}); });
}); });
}); });
#if 1
static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k, threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf, beta_global_val_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
beta_thread_buf(i)); beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_fwd_step_m_k); threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
}); });
#endif
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
...@@ -411,8 +376,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -411,8 +376,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
}); });
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); 2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment