// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/is_static.hpp" #include "ck/utility/print.hpp" #include "ck/utility/integral_constant.hpp" #include "ck/utility/sequence.hpp" #include "ck/utility/type.hpp" #include "ck/utility/enable_if.hpp" namespace ck { namespace detail { template struct TupleElementKey { __host__ __device__ constexpr TupleElementKey() = default; }; template struct TupleElementKeyData { using DataType = Data; #if 0 // workaround compiler complaint about implicitly-deleted default constructor __host__ __device__ constexpr TupleElementKeyData() = default; #else __host__ __device__ constexpr TupleElementKeyData() : mData{} {} #endif template , TupleElementKeyData>::value, bool>::type = false> __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward(v)) { } DataType mData; }; // for read access of tuple element template __host__ __device__ constexpr const Data& get_tuple_element_data_reference(const TupleElementKeyData& x) { return static_cast(x.mData); } // for write access of tuple element template __host__ __device__ constexpr Data& get_tuple_element_data_reference(TupleElementKeyData& x) { return x.mData; } // TODO: not sure the use of reference is correct template __host__ __device__ constexpr Data&& get_tuple_element_data_reference(TupleElementKeyData&& x) { return static_cast(x.mData); } // for infering type of tuple element template __host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData& x) { return std::forward(x.mData); } template struct TupleImpl; template struct TupleImpl, Xs...> : TupleElementKeyData, Xs>... { __host__ __device__ constexpr TupleImpl() = default; template , TupleImpl>::value, bool>::type = false> __host__ __device__ constexpr TupleImpl(Y&& y) : TupleElementKeyData, Xs>(std::forward(y))... { } template = 2, bool>::type = false> __host__ __device__ constexpr TupleImpl(Ys&&... ys) : TupleElementKeyData, Xs>(std::forward(ys))... { static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), "wrong! inconsistent size"); } __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } template __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey) const { return get_tuple_element_data_reference>(*this); } template __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey) { return get_tuple_element_data_reference>(*this); } }; } // namespace detail template struct Tuple : detail::TupleImpl::type, Xs...> { using base = detail::TupleImpl::type, Xs...>; __host__ __device__ constexpr Tuple() = default; template , Tuple>::value, bool>::type = false> __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) { } template = 2, bool>::type = false> __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward(ys)...) { } __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } // read access template __host__ __device__ constexpr const auto& At() const { static_assert(I < base::Size(), "wrong! out of range"); return base::GetElementDataByKey(detail::TupleElementKey{}); } // write access template __host__ __device__ constexpr auto& At() { static_assert(I < base::Size(), "wrong! out of range"); return base::GetElementDataByKey(detail::TupleElementKey{}); } // read access template __host__ __device__ constexpr const auto& At(Number) const { static_assert(I < base::Size(), "wrong! out of range"); return base::GetElementDataByKey(detail::TupleElementKey{}); } // write access template __host__ __device__ constexpr auto& At(Number) { static_assert(I < base::Size(), "wrong! out of range"); return base::GetElementDataByKey(detail::TupleElementKey{}); } // read access template __host__ __device__ constexpr const auto& operator[](Number i) const { return At(i); } // write access template __host__ __device__ constexpr auto& operator()(Number i) { return At(i); } // WARNING: needed by compiler for C++ structured binding support only, don't use this function! template __host__ __device__ constexpr const auto& get() const { return this->template At(); } // WARNING: needed bu compiler for C++ structured binding support only, don't use this function! template __host__ __device__ constexpr auto& get() { return this->template At(); } template __host__ __device__ constexpr auto operator=(const T& a) { static_assert(T::Size() == Size(), "wrong! size not the same"); static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); return *this; } __host__ __device__ static constexpr bool IsStatic() { bool flag = true; static_for<0, sizeof...(Xs), 1>{}([&flag](auto i) { flag &= is_static_v>>; }); return flag; } // FIXME: remove __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ void Print() const { printf("Tuple{size: %d, data: [", static_cast(Size())); static_for<0, Size(), 1>{}([&](auto i) { print(At(i)); if(i < Size() - 1) { printf(", "); } }); printf("]}"); } }; template <> struct Tuple<> { __host__ __device__ constexpr Tuple() = default; __host__ __device__ static constexpr index_t Size() { return 0; } template __host__ __device__ constexpr auto operator=(const T&) { return *this; } __host__ __device__ static constexpr bool IsStatic() { return true; } // FIXME: remove __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } }; template __host__ __device__ constexpr bool operator==(const Tuple& a, const Tuple& b) { bool same = true; static_for<0, sizeof...(Xs), 1>{}([&](auto i) { if(a[i] != b[i]) { same = false; } }); return same; } template __host__ __device__ constexpr bool operator!=(const Tuple& a, const Tuple& b) { return !(a == b); } template struct tuple_element { // type should keep the cv/ref qualifier of original tuple element using type = decltype(detail::get_tuple_element_data>(TTuple{})); }; template using tuple_element_t = typename tuple_element::type; template __host__ __device__ constexpr auto make_tuple(Xs&&... xs) { return Tuple...>(std::forward(xs)...); } // https://en.cppreference.com/w/cpp/utility/tuple/tie template constexpr Tuple tie(Args&... args) noexcept { return {args...}; } } // namespace ck namespace std { // WARNING: needed by compiler for C++ structured binding support only, don't use this template struct tuple_size> : std::integral_constant { }; // WARNING: needed by compiler for C++ structured binding support only, don't use this template struct tuple_element> : ck::tuple_element> { }; } // namespace std