Commit 8b1f2238 authored by rocking

Share the VGPR of gamma and beta

parent f61c37d0
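In both kernels this commit drops the dedicated beta thread buffer and binds `beta_thread_buf` as a reference to `gamma_thread_buf`, so beta no longer needs its own VGPR storage. To make that aliasing safe, the fused `y = y * gamma + beta` update is split into two passes: gamma is multiplied in first and fully consumed, then beta is loaded into the now-free buffer and added in a separate loop nest. The sketch below is a minimal stand-alone illustration of that idea, not the kernel code itself; the real code uses StaticBuffer in the Vgpr address space and threadwise copy objects, and the names here are illustrative only.

    // Simplified illustration of sharing one register-resident buffer
    // between gamma and beta (names and sizes are hypothetical).
    #include <array>
    #include <cstddef>

    constexpr std::size_t SliceSize = 8;
    using ThreadBuffer = std::array<float, SliceSize>; // stands in for a VGPR StaticBuffer

    void apply_gamma_beta(ThreadBuffer& y, const float* gamma_src, const float* beta_src)
    {
        ThreadBuffer param_buf;              // single shared buffer
        ThreadBuffer& gamma_buf = param_buf;
        ThreadBuffer& beta_buf  = param_buf; // alias, mirrors "auto& beta_thread_buf = gamma_thread_buf;"

        // pass 1: load gamma and multiply it in; gamma values are dead afterwards
        for(std::size_t i = 0; i < SliceSize; ++i)
            gamma_buf[i] = gamma_src[i];
        for(std::size_t i = 0; i < SliceSize; ++i)
            y[i] = y[i] * gamma_buf[i];

        // pass 2: reuse the same storage for beta, then add it
        for(std::size_t i = 0; i < SliceSize; ++i)
            beta_buf[i] = beta_src[i];
        for(std::size_t i = 0; i < SliceSize; ++i)
            y[i] = y[i] + beta_buf[i];
    }

The trade-off is one extra sweep over the y values in exchange for a smaller per-thread register footprint, which can help occupancy when the thread slice and vector sizes are large.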
@@ -133,14 +133,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             },
             Number<ThreadBufferNumber>{});
 
-        auto beta_thread_buf = generate_tuple(
-            [&](auto) {
-                return StaticBuffer<AddressSpaceEnum::Vgpr,
-                                    ComputeDataType,
-                                    MThreadSliceSize * BetaSrcVectorSize,
-                                    true>{};
-            },
-            Number<ThreadBufferNumber>{});
+        auto& beta_thread_buf = gamma_thread_buf;
 
         auto y_thread_buf = generate_tuple(
             [&](auto) {
@@ -314,15 +307,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
             auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
             static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                threadwise_beta_load.Run(beta_grid_desc_m_k,
-                                         beta_global_val_buf,
-                                         thread_buffer_desc_m_k,
-                                         make_tuple(I0, I0),
-                                         beta_thread_buf(iK0));
-                if constexpr(iK0 != ThreadBufferNumber - 1)
-                    threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
-                                                            thread_copy_fwd_step_m_k);
-
                 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
@@ -335,7 +319,32 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                     // gamma & beta
                     y_thread_buf(iK0)(Number<offset_m_k>{}) =
                         y_thread_buf(iK0)(Number<offset_m_k>{}) *
-                            gamma_thread_buf(iK0)(Number<offset_m_k>{}) +
+                        gamma_thread_buf(iK0)(Number<offset_m_k>{});
+                });
+            });
+        });
+
+        static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                     beta_global_val_buf,
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
+                                     beta_thread_buf(i));
+            if constexpr(i != ThreadBufferNumber - 1)
+                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                        thread_copy_fwd_step_m_k);
+        });
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                    constexpr auto offset_m_k =
+                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+
+                    // beta
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) +
                         beta_thread_buf(iK0)(Number<offset_m_k>{});
                 });
             });
@@ -143,15 +143,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
             },
             Number<ThreadBufferNumber>{});
 
-        auto beta_thread_buf = generate_tuple(
-            [&](auto) {
-                return StaticBuffer<AddressSpaceEnum::Vgpr,
-                                    ComputeDataType,
-                                    MThreadSliceSize * BetaSrcVectorSize,
-                                    true>{};
-            },
-            Number<ThreadBufferNumber>{});
-
+        auto& beta_thread_buf = gamma_thread_buf;
 
         auto& y_thread_buf = x_thread_buf;
 
         StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
@@ -292,15 +284,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
             auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
             static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                threadwise_beta_load.Run(beta_grid_desc_m_k,
-                                         beta_global_val_buf,
-                                         thread_buffer_desc_m_k,
-                                         make_tuple(I0, I0),
-                                         beta_thread_buf(iK0));
-                if constexpr(iK0 != ThreadBufferNumber - 1)
-                    threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
-                                                            thread_copy_fwd_step_m_k);
-
                 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
@@ -313,7 +296,32 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                     // gamma & beta
                     y_thread_buf(iK0)(Number<offset_m_k>{}) =
                         y_thread_buf(iK0)(Number<offset_m_k>{}) *
-                            gamma_thread_buf(iK0)(Number<offset_m_k>{}) +
+                        gamma_thread_buf(iK0)(Number<offset_m_k>{});
+                });
+            });
+        });
+
+        static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                     beta_global_val_buf,
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
+                                     beta_thread_buf(i));
+            if constexpr(i != ThreadBufferNumber - 1)
+                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                        thread_copy_fwd_step_m_k);
+        });
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                    constexpr auto offset_m_k =
+                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+
+                    // beta
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) +
                         beta_thread_buf(iK0)(Number<offset_m_k>{});
                 });
             });