Commit 1a38e362 authored by rocking's avatar rocking
Browse files

Separate sweeponce flow and optimize the flow

parent e12a6be2
...@@ -265,7 +265,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -265,7 +265,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
var_thread_buf(I) = type_convert<AccDataType>(0.0f); var_thread_buf(I) = type_convert<AccDataType>(0.0f);
}); });
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) // Separate sweep once and sweep twice pipeline
if constexpr(SweepOnce)
{ {
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
...@@ -273,55 +274,43 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -273,55 +274,43 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf(i)); x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
});
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
if constexpr(!SweepOnce)
{
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
}
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k, threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf, gamma_global_val_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
gamma_thread_buf(i)); gamma_thread_buf(i));
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
thread_copy_fwd_step_m_k);
if constexpr(i != ThreadBufferNumber - 1)
{
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
}
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(iK0));
if constexpr(iK0 != ThreadBufferNumber - 1)
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
...@@ -331,33 +320,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -331,33 +320,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; divisor;
// gamma // gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) * y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{}); gamma_thread_buf(iK0)(Number<offset_m_k>{}) +
});
});
});
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) +
beta_thread_buf(iK0)(Number<offset_m_k>{}); beta_thread_buf(iK0)(Number<offset_m_k>{});
}); });
}); });
...@@ -369,16 +335,128 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -369,16 +335,128 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
y_thread_buf(i), y_thread_buf(i),
y_grid_desc_m_k, y_grid_desc_m_k,
y_global_val_buf); y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
if constexpr(i != ThreadBufferNumber - 1)
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
thread_copy_fwd_step_m_k);
}); });
} // end of sweep once
else
{
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
});
}
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, if constexpr(I > 0)
2 * thread_copy_bwd_step_m_k); block_sync_lds();
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k); int count = threadwise_welford.cur_count_;
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
} });
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf(i));
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) +
beta_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf(i),
y_grid_desc_m_k,
y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
}
} // end of sweep twice
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment