Commit df73287b authored by Chao Liu
Browse files

rework sequence

parent 33b5a855
......@@ -57,7 +57,7 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
using Type = ConstantTensorDescriptor<Lengths, Strides>;
using Type = ConstantTensorDescriptor;
static constexpr index_t nDim = Lengths::GetSize();
__host__ __device__ constexpr ConstantTensorDescriptor()
......@@ -195,19 +195,14 @@ struct ConstantTensorDescriptor
Number<unfold_stride>{} *
reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
// left and right lengths
constexpr auto lengths_pair = GetLengths().Split(Number<IDim>{});
constexpr auto left_lengths = lengths_pair.first;
constexpr auto right_lengths = lengths_pair.second.PopFront();
// left and right strides
constexpr auto strides_pair = GetStrides().Split(Number<IDim>{});
constexpr auto left_strides = strides_pair.first;
constexpr auto right_strides = strides_pair.second.PopFront();
// left and right
constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});
constexpr auto right = make_increasing_sequence(
Number<IDim + 1>{}, Number<GetNumOfDimension()>{}, Number<1>{});
return make_ConstantTensorDescriptor(
left_lengths.Append(fold_lengths).Append(right_lengths),
left_strides.Append(fold_strides).Append(right_strides));
GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right)),
GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right)));
}
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
......@@ -228,40 +223,28 @@ struct ConstantTensorDescriptor
"wrong! dimensions to be unfolded need to be packed");
});
// lengths
constexpr auto lens_pair1 = Lengths{}.Split(Number<LastUnfoldDim + 1>{});
constexpr auto right_lengths = lens_pair1.second;
constexpr auto lens_pair2 = lens_pair1.first.Split(Number<FirstUnfoldDim>{});
constexpr auto left_lengths = lens_pair2.first;
constexpr auto fold_lengths = lens_pair2.second;
constexpr index_t unfold_length =
accumulate_on_sequence(fold_lengths, std::multiplies<index_t>{}, Number<1>{});
constexpr auto new_lengths =
left_lengths.PopBack(Number<unfold_length>{}).Append(right_lengths);
// strides
constexpr auto strides_pair1 = Strides{}.Split(Number<LastUnfoldDim + 1>{});
constexpr auto right_strides = strides_pair1.second;
constexpr auto strides_pair2 = strides_pair1.first.Split(Number<FirstUnfoldDim>{});
constexpr auto left_strides = strides_pair2.first;
constexpr auto fold_strides = strides_pair2.second;
constexpr index_t unfold_stride = fold_strides.Back();
constexpr auto new_strides =
left_strides.PushBack(Number<unfold_stride>{}).Append(right_strides);
return make_ConstantTensorDescriptor(new_lengths, new_strides);
// left and right
constexpr auto left =
make_increasing_sequence(Number<0>{}, Number<FirstUnfoldDim>{}, Number<1>{});
constexpr auto middle = make_increasing_sequence(
Number<FirstUnfoldDim>{}, Number<LastUnfoldDim + 1>{}, Number<1>{});
constexpr auto right = make_increasing_sequence(
Number<LastUnfoldDim + 1>{}, Number<GetNumOfDimension()>{}, Number<1>{});
// length and stride
constexpr index_t unfold_length = accumulate_on_sequence(
GetLengths().Extract(middle), std::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
return make_ConstantTensorDescriptor(GetLengths()
.Extract(left)
.PushBack(Number<unfold_length>{})
.Append(GetLengths().Extract(right)),
GetStrides()
.Extract(left)
.PushBack(Number<unfold_stride>{})
.Append(GetStrides().Extract(right)));
}
template <index_t... IRs>
......
......@@ -2,12 +2,10 @@
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"
struct EmptySequence;
template <index_t... Is>
struct Sequence
{
using Type = Sequence<Is...>;
using Type = Sequence;
static constexpr index_t mSize = sizeof...(Is);
......@@ -72,101 +70,62 @@ struct Sequence
return Sequence<Is..., Xs...>{};
}
__host__ __device__ constexpr auto Append(EmptySequence) const;
// Builds a new Sequence whose elements are Get(Number<Ns>{}) for each Ns,
// in the same order as the Number<> arguments were passed.
template <index_t... Ns>
__host__ __device__ constexpr auto Extract(Number<Ns>...) const
{
    using Picked = Sequence<Get(Number<Ns>{})...>;
    return Picked{};
}
// Recursive helper used when splitting a sequence: each step appends the front
// element of SecondSeq to FirstSeq and drops it from SecondSeq.
// N is the compile-time recursion counter.
template <index_t N>
struct split_impl
{
template <class FirstSeq, class SecondSeq>
__host__ __device__ constexpr auto operator()(FirstSeq, SecondSeq) const
{
// NOTE(review): both locals are declared index_t although they are initialized
// from PushBack()/PopFront() results, which look like sequences -- confirm.
constexpr index_t new_first = FirstSeq{}.PushBack(Number<SecondSeq{}.Front()>{});
constexpr index_t new_second = SecondSeq{}.PopFront();
// Recurse with N - 1 while N > 0; otherwise pair up the two parts.
// NOTE(review): the return statements below live inside the lambdas; nothing is
// returned from operator() itself -- confirm this is intended.
static_if<(N > 0)>{}([&](auto fwd) {
return split_impl<N - 1>{}(new_first, fwd(new_second));
}).else_([&](auto fwd) { return std::make_pair(new_first, fwd(new_second)); });
}
};
// split one sequence into two sequences: [0, I) and [I, mSize)
// return type is std::pair
template <index_t I>
__host__ __device__ constexpr auto Split(Number<I>) const;
template <index_t I, index_t X>
__host__ __device__ constexpr auto Modify(Number<I>, Number<X>) const
template <index_t... Ns>
__host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
{
constexpr auto first_second = Split(Number<I>{});
constexpr auto left = first_second.first;
constexpr auto right = first_second.second.PopFront();
return left.PushBack(Number<X>{}).Append(right);
return Sequence<Get(Number<Ns>{})...>{};
}
};
struct EmptySequence
{
// The empty sequence holds no elements, so its size is always zero.
__host__ __device__ static constexpr index_t GetSize()
{
    return 0;
}
template <class, class>
struct sequence_merge;
// Prepending I to the empty sequence yields the one-element Sequence<I>.
template <index_t I>
__host__ __device__ constexpr auto PushFront(Number<I>) const
{
    using Single = Sequence<I>;
    return Single{};
}
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
using Type = Sequence<Xs..., Ys...>;
};
// Appending I to the empty sequence yields the one-element Sequence<I>.
template <index_t I>
__host__ __device__ constexpr auto PushBack(Number<I>) const
{
    using Single = Sequence<I>;
    return Single{};
}
template <index_t IBegin, index_t NSize, index_t Increment>
struct increasing_sequence_gen
{
static constexpr index_t NSizeLeft = NSize / 2;
template <class Seq>
__host__ __device__ constexpr Seq Append(Seq) const
{
return Seq{};
}
using Type =
sequence_merge<typename increasing_sequence_gen<IBegin, NSizeLeft, Increment>::Type,
typename increasing_sequence_gen<IBegin + NSizeLeft * Increment,
NSize - NSizeLeft,
Increment>::Type>;
};
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::Append(EmptySequence) const
template <index_t IBegin, index_t Increment>
struct increasing_sequence_gen<IBegin, 1, Increment>
{
return Type{};
}
using Type = Sequence<IBegin>;
};
// split one sequence into two sequences: [0, I) and [I, mSize)
// return type is std::pair
template <index_t... Is>
template <index_t I>
__host__ __device__ constexpr auto Sequence<Is...>::Split(Number<I>) const
template <index_t IBegin, index_t Increment>
struct increasing_sequence_gen<IBegin, 0, Increment>
{
static_assert(I <= GetSize(), "wrong! split position is too high!");
static_if<(I == 0)>{}([&](auto fwd) { return std::make_pair(EmptySequence{}, fwd(Type{})); });
static_if<(I == GetSize())>{}(
[&](auto fwd) { return std::make_pair(Type{}, fwd(EmptySequence{})); });
static_if<(I > 0 && I < GetSize())>{}(
[&](auto fwd) { return split_impl<I>{}(EmptySequence{}, fwd(Type{})); });
}
using Type = Sequence<>;
};
#if 0
template <index_t IBegin, index_t IEnd, index_t Increment>
__host__ __device__ auto make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
__host__ __device__ constexpr auto
make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
{
static_assert(IBegin < IEnd, (IEnd - IBegin) % Increment == 0, "wrong!");
static_assert(IBegin <= IEnd && Increment > 0, "wrong!");
constexpr index_t NSize = (IEnd - IBegin) / Increment;
// not implemented
return increasing_sequence_gen<IBegin, NSize, Increment>{};
}
#endif
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
......@@ -222,7 +181,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
{
constexpr auto seq_x = Sequence<Xs...>{};
#if 0
#if 0 // doesn't compile
static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
constexpr auto I = decltype(Iter){};
static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment