Commit 4ee40bcc authored by letaoqin's avatar letaoqin
Browse files

change warp_welford.hpp

parent 63214d01
...@@ -44,9 +44,9 @@ struct WarpMergeWelford ...@@ -44,9 +44,9 @@ struct WarpMergeWelford
constexpr index_t idim_p_lane = NDimP - 1; constexpr index_t idim_p_lane = NDimP - 1;
const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id()); // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
const auto rs_idx = // const auto rs_idx =
mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); // mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
...@@ -78,13 +78,15 @@ struct WarpMergeWelford ...@@ -78,13 +78,15 @@ struct WarpMergeWelford
// reduction sweep forward // reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) { static_for<0, nstage, 1>{}([&](auto istage) {
constexpr index_t lid_delta = // xor
lid_over_rid_derivative * (1 << (nstage - istage - 1)); index_t src_lane =
(__lane_id()) ^
(number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane // pull data from remote lane
const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta); const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta); const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta); const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
// welford merge // welford merge
Merge(v_local_mean, Merge(v_local_mean,
...@@ -97,48 +99,6 @@ struct WarpMergeWelford ...@@ -97,48 +99,6 @@ struct WarpMergeWelford
} }
}); });
// cross-lane broadcast for replication
// only broadcast on R dimension correspond to lane
// (lane id maps to this R dimension)
if constexpr(BroadcastLane)
{
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
const index_t r_id = rs_idx[idim_r];
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// broadcast sweep backward
static_for<0, nstage, 1>{}([&](auto istage) {
// do I hold reduced data?
const bool do_i_hold_reduced_data = r_id < (1 << istage);
constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);
// pull data from remote lane
const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta);
const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta);
const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta);
// decide whether to update local data with remote data
v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean;
v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var;
v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count;
});
}
});
}
mean_tensor.get_thread_buffer()(i) = v_local_mean; mean_tensor.get_thread_buffer()(i) = v_local_mean;
if constexpr(GetActualVariance) if constexpr(GetActualVariance)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment