Commit ffa2c520 authored by Chao Liu
Browse files

refactoring tuple

parent 6cd94d98
...@@ -21,7 +21,7 @@ struct DynamicNativeTensorDescriptor ...@@ -21,7 +21,7 @@ struct DynamicNativeTensorDescriptor
} }
__host__ __device__ explicit constexpr DynamicNativeTensorDescriptor() __host__ __device__ explicit constexpr DynamicNativeTensorDescriptor()
: lengths_{make_zero_array<index_t, NDim>()}, strides_{make_zero_array<index_t, NDim>()} : lengths_{make_zero_multi_index<NDim>()}, strides_{make_zero_multi_index<NDim>()}
{ {
} }
......
...@@ -408,13 +408,13 @@ transform_dynamic_tensor_descriptor_v2(const OldTensorDescriptor& old_tensor_des ...@@ -408,13 +408,13 @@ transform_dynamic_tensor_descriptor_v2(const OldTensorDescriptor& old_tensor_des
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
// put everything together // put everything together
const auto all_transforms = merge_tuples(old_tensor_desc.GetTransforms(), new_transforms); const auto all_transforms = tuple_cat(old_tensor_desc.GetTransforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss = constexpr auto all_low_dim_hidden_idss =
merge_tuples(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); tuple_cat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss = constexpr auto all_up_dim_hidden_idss =
merge_tuples(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); tuple_cat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
return DynamicTensorDescriptor_v2<decltype(all_transforms), return DynamicTensorDescriptor_v2<decltype(all_transforms),
decltype(all_low_dim_hidden_idss), decltype(all_low_dim_hidden_idss),
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace ck { namespace ck {
#if 1 #if 1 // dynamically indexed array
template <index_t N> template <index_t N>
using MultiIndex = Array<index_t, N>; using MultiIndex = Array<index_t, N>;
...@@ -22,7 +22,7 @@ __host__ __device__ constexpr auto make_multi_index(Xs&&... xs) ...@@ -22,7 +22,7 @@ __host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
return make_array<const index_t>(std::forward<const Xs>(xs)...); return make_array<const index_t>(std::forward<const Xs>(xs)...);
} }
#endif #endif
#else #else // statically indexed array
template <index_t N> template <index_t N>
using MultiIndex = StaticallyIndexedArray<index_t, N>; using MultiIndex = StaticallyIndexedArray<index_t, N>;
......
...@@ -16,9 +16,9 @@ template <index_t... Is> ...@@ -16,9 +16,9 @@ template <index_t... Is>
struct unpack_impl<Sequence<Is...>> struct unpack_impl<Sequence<Is...>>
{ {
template <typename F, typename X> template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F f, const X& x) const __host__ __device__ constexpr auto operator()(F&& f, X&& x) const
{ {
return f(x.At(Number<Is>{})...); return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
} }
}; };
...@@ -30,26 +30,32 @@ template <index_t... Is, index_t... Js> ...@@ -30,26 +30,32 @@ template <index_t... Is, index_t... Js>
struct unpack2_impl<Sequence<Is...>, Sequence<Js...>> struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
{ {
template <typename F, typename X, typename Y> template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto operator()(F f, const X& x, const Y& y) const __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
{ {
return f(x.At(Number<Is>{})..., y.At(Number<Js>{})...); return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
std::forward<Y>(y).At(Number<Js>{})...);
} }
}; };
} // namespace detail } // namespace detail
template <typename F, typename X> template <typename F, typename X>
__host__ __device__ constexpr auto unpack(F f, const X& x) __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{ {
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X::Size(), 1>::type>{}(f, x); using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
} }
// TODO: properly implement unpack that takes any number of containers // TODO: properly implement unpack that takes any number of containers
template <typename F, typename X, typename Y> template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto unpack(F f, const X& x, const Y& y) __host__ __device__ constexpr auto unpack(F&& f, X&& x, Y&& y)
{ {
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X::Size(), 1>::type, using X_ = remove_reference_t<X>;
typename arithmetic_sequence_gen<0, Y::Size(), 1>::type>{}(f, x, y); using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
} }
} // namespace ck } // namespace ck
......
...@@ -24,6 +24,10 @@ struct TupleElement ...@@ -24,6 +24,10 @@ struct TupleElement
{ {
} }
__host__ __device__ explicit constexpr TupleElement(const TupleElement&) = default;
__host__ __device__ explicit constexpr TupleElement(TupleElement&&) = default;
Data mData; Data mData;
}; };
...@@ -39,11 +43,14 @@ __host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x ...@@ -39,11 +43,14 @@ __host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x
return x.mData; return x.mData;
} }
#if 0
// TODO: not sure the use of reference is correct
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x) __host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
{ {
return static_cast<Data&&>(x.mData); return static_cast<Data&&>(x.mData);
} }
#endif
template <typename Indices, typename... Xs> template <typename Indices, typename... Xs>
struct TupleImpl; struct TupleImpl;
...@@ -53,14 +60,21 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs> ...@@ -53,14 +60,21 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{ {
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()... __host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
{ {
static_assert(sizeof...(Is) == sizeof...(Xs), "wrong! inconsistent size");
} }
template <typename... Ys> template <typename... Ys>
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys) __host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))... : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{ {
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size");
} }
__host__ __device__ explicit constexpr TupleImpl(const TupleImpl&) = default;
__host__ __device__ explicit constexpr TupleImpl(TupleImpl&&) = default;
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I> template <index_t I>
...@@ -89,6 +103,10 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -89,6 +103,10 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
{ {
} }
__host__ __device__ explicit constexpr Tuple(const Tuple&) = default;
__host__ __device__ explicit constexpr Tuple(Tuple&&) = default;
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I> template <index_t I>
......
...@@ -13,9 +13,10 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>) ...@@ -13,9 +13,10 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
} }
template <typename... Tuples> template <typename... Tuples>
__host__ __device__ constexpr auto merge_tuples(Tuples... tuples) __host__ __device__ constexpr auto tuple_cat(Tuples&&... tuples)
{ {
return unpack([&tuples...](auto... xs) { return make_tuple(xs...); }, tuples...); return unpack([&](auto&&... xs) { return make_tuple(std::forward<decltype(xs)>(xs)...); },
std::forward<Tuples>(tuples)...);
} }
namespace detail { namespace detail {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment