Commit df73287b authored by Chao Liu
Browse files

rework sequence

parent 33b5a855
...@@ -57,7 +57,7 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0 ...@@ -57,7 +57,7 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
template <class Lengths, class Strides> template <class Lengths, class Strides>
struct ConstantTensorDescriptor struct ConstantTensorDescriptor
{ {
using Type = ConstantTensorDescriptor<Lengths, Strides>; using Type = ConstantTensorDescriptor;
static constexpr index_t nDim = Lengths::GetSize(); static constexpr index_t nDim = Lengths::GetSize();
__host__ __device__ constexpr ConstantTensorDescriptor() __host__ __device__ constexpr ConstantTensorDescriptor()
...@@ -195,19 +195,14 @@ struct ConstantTensorDescriptor ...@@ -195,19 +195,14 @@ struct ConstantTensorDescriptor
Number<unfold_stride>{} * Number<unfold_stride>{} *
reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{}); reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
// left and right lengths // left and right
constexpr auto lengths_pair = GetLengths().Split(Number<IDim>{}); constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});
constexpr auto left_lengths = lengths_pair.first; constexpr auto right = make_increasing_sequence(
constexpr auto right_lengths = lengths_pair.second.PopFront(); Number<IDim + 1>{}, Number<GetNumOfDimension()>{}, Number<1>{});
// left and right strides
constexpr auto strides_pair = GetStrides().Split(Number<IDim>{});
constexpr auto left_strides = strides_pair.first;
constexpr auto right_strides = strides_pair.second.PopFront();
return make_ConstantTensorDescriptor( return make_ConstantTensorDescriptor(
left_lengths.Append(fold_lengths).Append(right_lengths), GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right)),
left_strides.Append(fold_strides).Append(right_strides)); GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right)));
} }
template <index_t FirstUnfoldDim, index_t LastUnfoldDim> template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
...@@ -228,40 +223,28 @@ struct ConstantTensorDescriptor ...@@ -228,40 +223,28 @@ struct ConstantTensorDescriptor
"wrong! dimensions to be unfolded need to be packed"); "wrong! dimensions to be unfolded need to be packed");
}); });
// lengths // left and right
constexpr auto lens_pair1 = Lengths{}.Split(Number<LastUnfoldDim + 1>{}); constexpr auto left =
make_increasing_sequence(Number<0>{}, Number<FirstUnfoldDim>{}, Number<1>{});
constexpr auto right_lengths = lens_pair1.second; constexpr auto middle = make_increasing_sequence(
Number<FirstUnfoldDim>{}, Number<LastUnfoldDim + 1>{}, Number<1>{});
constexpr auto lens_pair2 = lens_pair1.first.Split(Number<FirstUnfoldDim>{}); constexpr auto right = make_increasing_sequence(
Number<LastUnfoldDim + 1>{}, Number<GetNumOfDimension()>{}, Number<1>{});
constexpr auto left_lengths = lens_pair2.first;
// length and stride
constexpr auto fold_lengths = lens_pair2.second; constexpr index_t unfold_length = accumulate_on_sequence(
GetLengths().Extract(middle), std::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_length =
accumulate_on_sequence(fold_lengths, std::multiplies<index_t>{}, Number<1>{}); constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
constexpr auto new_lengths = return make_ConstantTensorDescriptor(GetLengths()
left_lengths.PopBack(Number<unfold_length>{}).Append(right_lengths); .Extract(left)
.PushBack(Number<unfold_length>{})
// strides .Append(GetLengths().Extract(right)),
constexpr auto strides_pair1 = Strides{}.Split(Number<LastUnfoldDim + 1>{}); GetStrides()
.Extract(left)
constexpr auto right_strides = strides_pair1.second; .PushBack(Number<unfold_stride>{})
.Append(GetStrides().Extract(right)));
constexpr auto strides_pair2 = strides_pair1.first.Split(Number<FirstUnfoldDim>{});
constexpr auto left_strides = strides_pair2.first;
constexpr auto fold_strides = strides_pair2.second;
constexpr index_t unfold_stride = fold_strides.Back();
constexpr auto new_strides =
left_strides.PushBack(Number<unfold_stride>{}).Append(right_strides);
return make_ConstantTensorDescriptor(new_lengths, new_strides);
} }
template <index_t... IRs> template <index_t... IRs>
......
...@@ -2,12 +2,10 @@ ...@@ -2,12 +2,10 @@
#include "constant_integral.hip.hpp" #include "constant_integral.hip.hpp"
#include "functional.hip.hpp" #include "functional.hip.hpp"
struct EmptySequence;
template <index_t... Is> template <index_t... Is>
struct Sequence struct Sequence
{ {
using Type = Sequence<Is...>; using Type = Sequence;
static constexpr index_t mSize = sizeof...(Is); static constexpr index_t mSize = sizeof...(Is);
...@@ -72,101 +70,62 @@ struct Sequence ...@@ -72,101 +70,62 @@ struct Sequence
return Sequence<Is..., Xs...>{}; return Sequence<Is..., Xs...>{};
} }
__host__ __device__ constexpr auto Append(EmptySequence) const;
template <index_t... Ns> template <index_t... Ns>
__host__ __device__ constexpr auto Extract(Number<Ns>...) const __host__ __device__ constexpr auto Extract(Number<Ns>...) const
{ {
return Sequence<Get(Number<Ns>{})...>{}; return Sequence<Get(Number<Ns>{})...>{};
} }
template <index_t N> template <index_t... Ns>
struct split_impl __host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
{
template <class FirstSeq, class SecondSeq>
__host__ __device__ constexpr auto operator()(FirstSeq, SecondSeq) const
{ {
constexpr index_t new_first = FirstSeq{}.PushBack(Number<SecondSeq{}.Front()>{}); return Sequence<Get(Number<Ns>{})...>{};
constexpr index_t new_second = SecondSeq{}.PopFront();
static_if<(N > 0)>{}([&](auto fwd) {
return split_impl<N - 1>{}(new_first, fwd(new_second));
}).else_([&](auto fwd) { return std::make_pair(new_first, fwd(new_second)); });
} }
}; };
// split one sequence into two sequences: [0, I) and [I, mSize)
// return type is std::pair
template <index_t I>
__host__ __device__ constexpr auto Split(Number<I>) const;
template <index_t I, index_t X>
__host__ __device__ constexpr auto Modify(Number<I>, Number<X>) const
{
constexpr auto first_second = Split(Number<I>{});
constexpr auto left = first_second.first; template <class, class>
constexpr auto right = first_second.second.PopFront(); struct sequence_merge;
return left.PushBack(Number<X>{}).Append(right); template <index_t... Xs, index_t... Ys>
} struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
using Type = Sequence<Xs..., Ys...>;
}; };
struct EmptySequence template <index_t IBegin, index_t NSize, index_t Increment>
struct increasing_sequence_gen
{ {
__host__ __device__ static constexpr index_t GetSize() { return 0; } static constexpr index_t NSizeLeft = NSize / 2;
template <index_t I>
__host__ __device__ constexpr auto PushFront(Number<I>) const
{
return Sequence<I>{};
}
template <index_t I> using Type =
__host__ __device__ constexpr auto PushBack(Number<I>) const sequence_merge<typename increasing_sequence_gen<IBegin, NSizeLeft, Increment>::Type,
{ typename increasing_sequence_gen<IBegin + NSizeLeft * Increment,
return Sequence<I>{}; NSize - NSizeLeft,
} Increment>::Type>;
template <class Seq>
__host__ __device__ constexpr Seq Append(Seq) const
{
return Seq{};
}
}; };
template <index_t... Is> template <index_t IBegin, index_t Increment>
__host__ __device__ constexpr auto Sequence<Is...>::Append(EmptySequence) const struct increasing_sequence_gen<IBegin, 1, Increment>
{ {
return Type{}; using Type = Sequence<IBegin>;
} };
// split one sequence into two sequences: [0, I) and [I, mSize) template <index_t IBegin, index_t Increment>
// return type is std::pair struct increasing_sequence_gen<IBegin, 0, Increment>
template <index_t... Is>
template <index_t I>
__host__ __device__ constexpr auto Sequence<Is...>::Split(Number<I>) const
{ {
static_assert(I <= GetSize(), "wrong! split position is too high!"); using Type = Sequence<>;
};
static_if<(I == 0)>{}([&](auto fwd) { return std::make_pair(EmptySequence{}, fwd(Type{})); });
static_if<(I == GetSize())>{}(
[&](auto fwd) { return std::make_pair(Type{}, fwd(EmptySequence{})); });
static_if<(I > 0 && I < GetSize())>{}(
[&](auto fwd) { return split_impl<I>{}(EmptySequence{}, fwd(Type{})); });
}
#if 0
template <index_t IBegin, index_t IEnd, index_t Increment> template <index_t IBegin, index_t IEnd, index_t Increment>
__host__ __device__ auto make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>) __host__ __device__ constexpr auto
make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
{ {
static_assert(IBegin < IEnd, (IEnd - IBegin) % Increment == 0, "wrong!"); static_assert(IBegin <= IEnd && Increment > 0, "wrong!");
// not implemented constexpr index_t NSize = (IEnd - IBegin) / Increment;
return increasing_sequence_gen<IBegin, NSize, Increment>{};
} }
#endif
template <index_t... Xs, index_t... Ys> template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>) __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
...@@ -222,7 +181,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>) ...@@ -222,7 +181,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
{ {
constexpr auto seq_x = Sequence<Xs...>{}; constexpr auto seq_x = Sequence<Xs...>{};
#if 0 #if 0 // doesn't compile
static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
constexpr auto I = decltype(Iter){}; constexpr auto I = decltype(Iter){};
static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow"); static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment