Commit 260f0e93 authored by Jianfeng yan's avatar Jianfeng yan
Browse files

removed GetIndices; refactored GetDstCoordinateResetStep

parent 6720ef75
......@@ -92,7 +92,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve
remove_cv_t<decltype(dst_scalar_per_access)>>;
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, "Wrong! ");
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
......@@ -154,8 +154,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
// TODO: wrap into a function, say SpaceFillingCurve::MoveCoordForward(dst_desc, dst_cood_, idx_1d)?
// TODO: Do we need the if-statement? GetForwardStep is not well-defined for the last access.
if constexpr(idx_1d.value != num_accesses - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
......@@ -197,60 +195,17 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve
{
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in Run(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
return reset_dst_data_step_;
}();
constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
constexpr auto reset_step = SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, I0);
return reset_dst_data_step;
return reset_step;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
......
......@@ -38,13 +38,24 @@ struct SpaceFillingCurve
ScalarPerVector;
}
template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd>
static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>, Number<AccessIdx1dEnd>)
{
static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative");
static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0");
static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative");
static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0");
constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{});
constexpr auto idx_end = GetIndex(Number<AccessIdx1dEnd>{});
return idx_end - idx_begin;
}
template <index_t AccessIdx1d>
static __device__ __host__ constexpr auto GetForwardStep(Number<AccessIdx1d>)
{
constexpr auto idx_curr = GetIndex(Number<AccessIdx1d>{});
constexpr auto idx_next = GetIndex(Number<AccessIdx1d + 1>{});
return idx_next - idx_curr;
static_assert(AccessIdx1d < GetNumOfAccess(), "1D index should be larger than 0");
return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d + 1>{});
}
template <index_t AccessIdx1d>
......@@ -52,24 +63,7 @@ struct SpaceFillingCurve
{
static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
constexpr auto idx_curr = GetIndex(Number<AccessIdx1d>{});
constexpr auto idx_prev = GetIndex(Number<AccessIdx1d - 1>{});
return idx_prev - idx_curr;
}
/*
* \brief Get all the multi-dimensional indices between given access_id and next access_id.
*/
template <typename DimAccessOrderOfSubTensor=DimAccessOrder, index_t AccessIdx1d>
static __device__ __host__ constexpr auto GetIndices(Number<AccessIdx1d>)
{
constexpr auto base_index = GetIndex(Number<AccessIdx1d>{});
// TODO: Should we use a zig-zag space-filling-curve here?
using SubSpaceFillingCurve = SpaceFillingCurve<ScalarsPerAccess, DimAccessOrderOfSubTensor, typename uniform_sequence_gen<nDim, 1>::type>;
constexpr auto compute_index = [base_index](auto k) constexpr {
return SubSpaceFillingCurve::GetIndex(k) + base_index;
};
return generate_tuple(compute_index, Number<ScalarPerVector>{});
return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d - 1>{});
}
template <index_t AccessIdx1d>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment