Commit ffa2c520 authored by Chao Liu's avatar Chao Liu
Browse files

refactoring tuple

parent 6cd94d98
......@@ -21,7 +21,7 @@ struct DynamicNativeTensorDescriptor
}
__host__ __device__ explicit constexpr DynamicNativeTensorDescriptor()
: lengths_{make_zero_array<index_t, NDim>()}, strides_{make_zero_array<index_t, NDim>()}
: lengths_{make_zero_multi_index<NDim>()}, strides_{make_zero_multi_index<NDim>()}
{
}
......
......@@ -408,13 +408,13 @@ transform_dynamic_tensor_descriptor_v2(const OldTensorDescriptor& old_tensor_des
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
// put everything together
const auto all_transforms = merge_tuples(old_tensor_desc.GetTransforms(), new_transforms);
const auto all_transforms = tuple_cat(old_tensor_desc.GetTransforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss =
merge_tuples(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
tuple_cat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss =
merge_tuples(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
tuple_cat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
return DynamicTensorDescriptor_v2<decltype(all_transforms),
decltype(all_low_dim_hidden_idss),
......
......@@ -5,7 +5,7 @@
namespace ck {
#if 1
#if 1 // dyanmically indexed array
template <index_t N>
using MultiIndex = Array<index_t, N>;
......@@ -22,7 +22,7 @@ __host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
return make_array<const index_t>(std::forward<const Xs>(xs)...);
}
#endif
#else
#else // statically index array
template <index_t N>
using MultiIndex = StaticallyIndexedArray<index_t, N>;
......
......@@ -16,9 +16,9 @@ template <index_t... Is>
struct unpack_impl<Sequence<Is...>>
{
template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F f, const X& x) const
__host__ __device__ constexpr auto operator()(F&& f, X&& x) const
{
return f(x.At(Number<Is>{})...);
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
}
};
......@@ -30,26 +30,32 @@ template <index_t... Is, index_t... Js>
struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
{
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto operator()(F f, const X& x, const Y& y) const
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
{
return f(x.At(Number<Is>{})..., y.At(Number<Js>{})...);
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
std::forward<Y>(y).At(Number<Js>{})...);
}
};
} // namespace detail
template <typename F, typename X>
__host__ __device__ constexpr auto unpack(F f, const X& x)
__host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X::Size(), 1>::type>{}(f, x);
using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto unpack(F f, const X& x, const Y& y)
__host__ __device__ constexpr auto unpack(F&& f, X&& x, Y&& y)
{
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X::Size(), 1>::type,
typename arithmetic_sequence_gen<0, Y::Size(), 1>::type>{}(f, x, y);
using X_ = remove_reference_t<X>;
using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
} // namespace ck
......
......@@ -24,6 +24,10 @@ struct TupleElement
{
}
__host__ __device__ explicit constexpr TupleElement(const TupleElement&) = default;
__host__ __device__ explicit constexpr TupleElement(TupleElement&&) = default;
Data mData;
};
......@@ -39,11 +43,14 @@ __host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x
return x.mData;
}
#if 0
// TODO: not sure the use of reference is correct
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
{
return static_cast<Data&&>(x.mData);
}
#endif
template <typename Indices, typename... Xs>
struct TupleImpl;
......@@ -53,14 +60,21 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
{
static_assert(sizeof...(Is) == sizeof...(Xs), "wrong! inconsistent size");
}
template <typename... Ys>
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size");
}
__host__ __device__ explicit constexpr TupleImpl(const TupleImpl&) = default;
__host__ __device__ explicit constexpr TupleImpl(TupleImpl&&) = default;
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
......@@ -89,6 +103,10 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
{
}
__host__ __device__ explicit constexpr Tuple(const Tuple&) = default;
__host__ __device__ explicit constexpr Tuple(Tuple&&) = default;
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
......
......@@ -13,9 +13,10 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
}
template <typename... Tuples>
__host__ __device__ constexpr auto merge_tuples(Tuples... tuples)
__host__ __device__ constexpr auto tuple_cat(Tuples&&... tuples)
{
return unpack([&tuples...](auto... xs) { return make_tuple(xs...); }, tuples...);
return unpack([&](auto&&... xs) { return make_tuple(std::forward<decltype(xs)>(xs)...); },
std::forward<Tuples>(tuples)...);
}
namespace detail {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment