Commit 0271338e authored by Chao Liu's avatar Chao Liu
Browse files

added ReorderGiveOld2New() in Sequence and ConstantTensorDescriptor

parent fdcfae3a
...@@ -419,6 +419,13 @@ struct ConstantTensorDescriptor ...@@ -419,6 +419,13 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})), return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{}; decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
} }
template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
}
}; };
template <class Lengths> template <class Lengths>
......
...@@ -74,8 +74,8 @@ __device__ void threadwise_generic_tensor_slice_copy_v1( ...@@ -74,8 +74,8 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
constexpr auto data_multi_id_in_access_order = constexpr auto data_multi_id_in_access_order =
access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{}); access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
constexpr auto data_multi_id = reorder_array_given_old2new( constexpr auto data_multi_id =
sequence2array(data_multi_id_in_access_order), DimAccessOrder{}); data_multi_id_in_access_order.ReorderGivenOld2New(DimAccessOrder{});
const index_t src_index = const index_t src_index =
SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id); SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
......
...@@ -6,12 +6,27 @@ ...@@ -6,12 +6,27 @@
namespace ck { namespace ck {
template <index_t...>
struct Sequence;
template <class Seq, index_t I>
struct sequence_split;
template <class> template <class>
struct is_valid_sequence_map; struct sequence_reverse;
template <class> template <class>
struct sequence_map_inverse; struct sequence_map_inverse;
template <class>
struct is_valid_sequence_map;
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);
template <index_t... Is> template <index_t... Is>
struct Sequence struct Sequence
{ {
...@@ -71,7 +86,10 @@ struct Sequence ...@@ -71,7 +86,10 @@ struct Sequence
return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{}); return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
} }
__host__ __device__ static constexpr auto Reverse(); __host__ __device__ static constexpr auto Reverse()
{
return typename sequence_reverse<Type>::type{};
}
__host__ __device__ static constexpr auto Front() __host__ __device__ static constexpr auto Front()
{ {
...@@ -85,9 +103,9 @@ struct Sequence ...@@ -85,9 +103,9 @@ struct Sequence
return Get(Number<mSize - 1>{}); return Get(Number<mSize - 1>{});
} }
__host__ __device__ static constexpr auto PopFront(); __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
__host__ __device__ static constexpr auto PopBack(); __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
template <index_t... Xs> template <index_t... Xs>
__host__ __device__ static constexpr auto PushFront(Sequence<Xs...>) __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
...@@ -126,7 +144,16 @@ struct Sequence ...@@ -126,7 +144,16 @@ struct Sequence
} }
template <index_t I, index_t X> template <index_t I, index_t X>
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>); __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
{
static_assert(I < GetSize(), "wrong!");
using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::SeqType0{};
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
}
template <class F> template <class F>
__host__ __device__ static constexpr auto Transform(F f) __host__ __device__ static constexpr auto Transform(F f)
...@@ -283,7 +310,8 @@ template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain> ...@@ -283,7 +310,8 @@ template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl struct sequence_map_inverse_impl
{ {
private: private:
static constexpr auto new_y2x = WorkingY2X::Modify(X2Y{}[XBegin], XBegin); static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
public: public:
using type = using type =
...@@ -417,8 +445,8 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>) ...@@ -417,8 +445,8 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
template <class Seq> template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq) __host__ __device__ constexpr auto sequence_pop_back(Seq)
{ {
static_assert(Seq{}.GetSize() > 0, "wrong! cannot pop an empty Sequence!"); static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");
return sequence_pop_front(Seq{}.Reverse()).Reverse(); return sequence_pop_front(Seq::Reverse()).Reverse();
} }
template <class F, index_t... Xs> template <class F, index_t... Xs>
...@@ -458,37 +486,6 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I ...@@ -458,37 +486,6 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse(); return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
} }
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopFront()
{
return sequence_pop_front(Type{});
}
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack()
{
return sequence_pop_back(Type{});
}
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::Reverse()
{
return typename sequence_reverse<Sequence<Is...>>::type{};
}
template <index_t... Is>
template <index_t I, index_t X>
__host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
{
static_assert(I < GetSize(), "wrong!");
using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::SeqType0{};
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
}
template <index_t... Xs> template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>) __host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment