Commit 758f6977 authored by Jianfeng yan
Browse files

add SpaceFillingCurve::GetIndices()

parent 2db9e84d
...@@ -130,12 +130,15 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve ...@@ -130,12 +130,15 @@ struct ThreadwiseTensorSliceTransfer_v1r3_using_space_filling_curve
static_for<0, num_accesses, 1>{}([&](auto idx_1d) { static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); // constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
constexpr auto all_indices = SpaceFillingCurve::GetIndices(idx_1d);
// copy data from src_buf into dst_vector // copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// constexpr index_t src_offset = src_desc.CalculateOffset(
// src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + all_indices[i]);
SrcData dst_v; SrcData dst_v;
......
#include "math.hpp" #include "math.hpp"
#include "sequence.hpp" #include "sequence.hpp"
#include "sequence_helper.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "statically_indexed_array_multi_index.hpp" #include "statically_indexed_array_multi_index.hpp"
#include "tuple_helper.hpp" #include "tuple_helper.hpp"
...@@ -56,6 +57,21 @@ struct SpaceFillingCurve ...@@ -56,6 +57,21 @@ struct SpaceFillingCurve
return idx_prev - idx_curr; return idx_prev - idx_curr;
} }
/**
 * \brief Enumerate the multi-dimensional indices visited within a single access.
 *
 * The result is a tuple with ScalarPerVector entries. Entry k is the k-th index of a
 * sub-curve that traverses a tensor of lengths ScalarsPerAccess one scalar at a time,
 * shifted by GetIndex(Number<AccessIdx1d>{}), the base index of the requested access.
 *
 * \tparam DimAccessOrderOfSubTensor dimension access order used by the sub-curve;
 *         defaults to the DimAccessOrder of the enclosing curve.
 * \tparam AccessIdx1d 1-d id of the access whose indices are requested.
 *
 * NOTE(review): presumably ScalarPerVector equals the number of scalars in one
 * ScalarsPerAccess block, so the sub-curve is walked exactly once -- confirm against
 * the definition of ScalarPerVector.
 */
template <typename DimAccessOrderOfSubTensor = DimAccessOrder, index_t AccessIdx1d>
static __device__ __host__ constexpr auto GetIndices(Number<AccessIdx1d>)
{
    // Sub-curve over the block covered by one access, stepping a single scalar per dimension.
    // TODO: Should we use a zig-zag space-filling-curve here?
    using OneScalarPerAccess = typename uniform_sequence_gen<nDim, 1>::type;
    using InnerCurve =
        SpaceFillingCurve<ScalarsPerAccess, DimAccessOrderOfSubTensor, OneScalarPerAccess>;

    // Index at which the requested access starts in the enclosing curve.
    constexpr auto access_origin = GetIndex(Number<AccessIdx1d>{});

    // Maps a position inside the access to its index in the enclosing tensor.
    constexpr auto index_of_scalar = [access_origin](auto scalar_id) constexpr {
        return InnerCurve::GetIndex(scalar_id) + access_origin;
    };

    return generate_tuple(index_of_scalar, Number<ScalarPerVector>{});
}
template <index_t AccessIdx1d> template <index_t AccessIdx1d>
static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>) static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>)
{ {
......
...@@ -29,9 +29,9 @@ void traverse_using_space_filling_curve() ...@@ -29,9 +29,9 @@ void traverse_using_space_filling_curve()
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
using TensorLengths = Sequence<4, 10, 9>; using TensorLengths = Sequence<16, 10, 9>;
using DimAccessOrder = Sequence<2, 0, 1>; using DimAccessOrder = Sequence<2, 0, 1>;
using ScalarsPerAccess = Sequence<1, 2, 3>; using ScalarsPerAccess = Sequence<4, 2, 3>;
using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>; using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;
constexpr auto expected = make_tuple(make_tuple(0, 0, 0), constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
...@@ -39,36 +39,36 @@ void traverse_using_space_filling_curve() ...@@ -39,36 +39,36 @@ void traverse_using_space_filling_curve()
make_tuple(0, 4, 0), make_tuple(0, 4, 0),
make_tuple(0, 6, 0), make_tuple(0, 6, 0),
make_tuple(0, 8, 0), make_tuple(0, 8, 0),
make_tuple(1, 8, 0), make_tuple(4, 8, 0),
make_tuple(1, 6, 0), make_tuple(4, 6, 0),
make_tuple(1, 4, 0), make_tuple(4, 4, 0),
make_tuple(1, 2, 0), make_tuple(4, 2, 0),
make_tuple(1, 0, 0), make_tuple(4, 0, 0),
make_tuple(2, 0, 0), make_tuple(8, 0, 0),
make_tuple(2, 2, 0), make_tuple(8, 2, 0),
make_tuple(2, 4, 0), make_tuple(8, 4, 0),
make_tuple(2, 6, 0), make_tuple(8, 6, 0),
make_tuple(2, 8, 0), make_tuple(8, 8, 0),
make_tuple(3, 8, 0), make_tuple(12, 8, 0),
make_tuple(3, 6, 0), make_tuple(12, 6, 0),
make_tuple(3, 4, 0), make_tuple(12, 4, 0),
make_tuple(3, 2, 0), make_tuple(12, 2, 0),
make_tuple(3, 0, 0), make_tuple(12, 0, 0),
make_tuple(3, 0, 3), make_tuple(12, 0, 3),
make_tuple(3, 2, 3), make_tuple(12, 2, 3),
make_tuple(3, 4, 3), make_tuple(12, 4, 3),
make_tuple(3, 6, 3), make_tuple(12, 6, 3),
make_tuple(3, 8, 3), make_tuple(12, 8, 3),
make_tuple(2, 8, 3), make_tuple(8, 8, 3),
make_tuple(2, 6, 3), make_tuple(8, 6, 3),
make_tuple(2, 4, 3), make_tuple(8, 4, 3),
make_tuple(2, 2, 3), make_tuple(8, 2, 3),
make_tuple(2, 0, 3), make_tuple(8, 0, 3),
make_tuple(1, 0, 3), make_tuple(4, 0, 3),
make_tuple(1, 2, 3), make_tuple(4, 2, 3),
make_tuple(1, 4, 3), make_tuple(4, 4, 3),
make_tuple(1, 6, 3), make_tuple(4, 6, 3),
make_tuple(1, 8, 3), make_tuple(4, 8, 3),
make_tuple(0, 8, 3), make_tuple(0, 8, 3),
make_tuple(0, 6, 3), make_tuple(0, 6, 3),
make_tuple(0, 4, 3), make_tuple(0, 4, 3),
...@@ -79,21 +79,21 @@ void traverse_using_space_filling_curve() ...@@ -79,21 +79,21 @@ void traverse_using_space_filling_curve()
make_tuple(0, 4, 6), make_tuple(0, 4, 6),
make_tuple(0, 6, 6), make_tuple(0, 6, 6),
make_tuple(0, 8, 6), make_tuple(0, 8, 6),
make_tuple(1, 8, 6), make_tuple(4, 8, 6),
make_tuple(1, 6, 6), make_tuple(4, 6, 6),
make_tuple(1, 4, 6), make_tuple(4, 4, 6),
make_tuple(1, 2, 6), make_tuple(4, 2, 6),
make_tuple(1, 0, 6), make_tuple(4, 0, 6),
make_tuple(2, 0, 6), make_tuple(8, 0, 6),
make_tuple(2, 2, 6), make_tuple(8, 2, 6),
make_tuple(2, 4, 6), make_tuple(8, 4, 6),
make_tuple(2, 6, 6), make_tuple(8, 6, 6),
make_tuple(2, 8, 6), make_tuple(8, 8, 6),
make_tuple(3, 8, 6), make_tuple(12, 8, 6),
make_tuple(3, 6, 6), make_tuple(12, 6, 6),
make_tuple(3, 4, 6), make_tuple(12, 4, 6),
make_tuple(3, 2, 6), make_tuple(12, 2, 6),
make_tuple(3, 0, 6)); make_tuple(12, 0, 6));
constexpr index_t num_accesses = SpaceFillingCurve::GetNumOfAccess(); constexpr index_t num_accesses = SpaceFillingCurve::GetNumOfAccess();
...@@ -128,4 +128,19 @@ void traverse_using_space_filling_curve() ...@@ -128,4 +128,19 @@ void traverse_using_space_filling_curve()
static_assert(forward_step[I1] == expected_step[I1]); static_assert(forward_step[I1] == expected_step[I1]);
static_assert(forward_step[I2] == expected_step[I2]); static_assert(forward_step[I2] == expected_step[I2]);
}); });
static_for<0, num_accesses - 1, 1>{}([&](auto i) {
constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
printf("idx_1d = %d, idx_md = [%d, %d, %d]\n",
i.value,
idx_curr[I0],
idx_curr[I1],
idx_curr[I2]);
constexpr auto all_indices = SpaceFillingCurve::GetIndices<Sequence<0, 1, 2>>(i);
static_for<0, SpaceFillingCurve::ScalarPerVector, 1>{}([&](auto j) {
printf(" [%d, %d, %d]\n", all_indices[j][I0], all_indices[j][I1], all_indices[j][I2]);
});
});
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment