Commit d7bb21c2 authored by wangshaojie6
Browse files

optimize group layer norm

parent 8daff431
...@@ -57,7 +57,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -57,7 +57,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}))); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
...@@ -73,8 +73,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -73,8 +73,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
static constexpr auto XThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto BetaThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto YThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
int thread_k_cluster_id) int thread_k_cluster_id)
...@@ -116,6 +122,47 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -116,6 +122,47 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto x_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * XSrcVectorSize,
true>{};
},
Number<XThreadBufferNumber>{});
auto gamma_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * GammaSrcVectorSize,
true>{};
},
Number<GammaThreadBufferNumber>{});
auto beta_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * BetaSrcVectorSize,
true>{};
},
Number<BetaThreadBufferNumber>{});
auto y_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * YDstVectorSize,
true>{};
},
Number<YThreadBufferNumber>{});
#if 0
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf; x_thread_buf;
...@@ -129,6 +176,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -129,6 +176,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf; y_thread_buf;
#endif
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
...@@ -142,9 +190,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -142,9 +190,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1]; const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>; // using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
// constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
// make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
// auto red_num = slice/vector;
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType, auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType, AccDataType,
...@@ -214,8 +267,11 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -214,8 +267,11 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
// Copy x from Cache // Copy x from Cache
// one pass: fwd, second pass: bwd // one pass: fwd, second pass: bwd
// constexpr auto thread_copy_fwd_step_m_k =
// make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k = constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize); make_multi_index(0, SweepOnce ? 0 : K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k = constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...@@ -238,16 +294,18 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -238,16 +294,18 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{ {
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf, x_global_val_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf); x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf); threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
});
} }
#if 1
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0) if constexpr(I > 0)
block_sync_lds(); block_sync_lds();
...@@ -255,6 +313,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -255,6 +313,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_; int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
}); });
#endif
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
...@@ -267,62 +326,94 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -267,62 +326,94 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
{ {
if constexpr(!SweepOnce) if constexpr(!SweepOnce)
{ {
threadwise_x_load.Run(x_grid_desc_m_k, static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
x_global_val_buf, threadwise_x_load.Run(x_grid_desc_m_k,
thread_buffer_desc_m_k, x_global_val_buf,
make_tuple(I0, I0), thread_buffer_desc_m_k,
x_thread_buf); make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
} }
threadwise_gamma_load.Run(gamma_grid_desc_m_k, #if 1
gamma_global_val_buf, static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
thread_buffer_desc_m_k, threadwise_gamma_load.Run(gamma_grid_desc_m_k,
make_tuple(I0, I0), gamma_global_val_buf,
gamma_thread_buf); thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf(i));
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_fwd_step_m_k);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
// normalize
y_thread_buf(Number<offset_m_k>{}) =
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
sqrt(var_thread_buf(iM) + epsilon);
// gamma
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
});
}); });
#endif
threadwise_beta_load.Run(beta_grid_desc_m_k, #if 0
beta_global_val_buf, static_for<0, gamma_thread_buf.Size(), 1>{}([&](auto i){
thread_buffer_desc_m_k, gamma_thread_buf(i) = 1;
make_tuple(I0, I0), beta_thread_buf(i) = 1;
beta_thread_buf); });
#endif
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { auto divisor = 1 / sqrt(var_thread_buf(iM) + epsilon);
constexpr auto offset_m_k = static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
#if 1
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
#endif
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});
#if 1
static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
#endif
// beta static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
y_thread_buf(Number<offset_m_k>{}) = static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{}); static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) +
beta_thread_buf(iK0)(Number<offset_m_k>{});
});
}); });
}); });
threadwise_y_store.Run(thread_buffer_desc_m_k, static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
make_tuple(I0, I0), threadwise_y_store.Run(thread_buffer_desc_m_k,
y_thread_buf, make_tuple(I0, I0),
y_grid_desc_m_k, y_thread_buf(i),
y_global_val_buf); y_grid_desc_m_k,
y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k); threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k); threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
} }
} }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment