// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_ARRAY_MULTI_INDEX_HPP
#define CK_ARRAY_MULTI_INDEX_HPP

#include "common_header.hpp"

namespace ck {

// A multi-dimensional index: a fixed-size Array holding N index_t coordinates.
template <index_t N>
using MultiIndex = Array<index_t, N>;

// Build a MultiIndex from a list of coordinates (one argument per dimension).
//
// Each argument is converted with braced initialisation (index_t{x}), so a
// narrowing conversion (e.g. from a wider integer or a floating-point value)
// is ill-formed at compile time instead of being silently truncated.
// The element type is given to make_array explicitly; presumably this is so an
// empty argument list still yields an index_t array (MultiIndex<0>) — confirm
// against make_array's overloads.
template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
{
    return make_array<index_t>(index_t{xs}...);
}

// Build a MultiIndex<NSize> whose NSize coordinates are all zero.
template <index_t NSize>
__host__ __device__ constexpr auto make_zero_multi_index()
{
    // uniform_sequence_gen<NSize, 0>::type is expected to be a compile-time
    // sequence of NSize zeros; unpack forwards its elements to make_multi_index.
    return unpack([](auto... xs) { return make_multi_index(xs...); },
                  typename uniform_sequence_gen<NSize, 0>::type{});
}

// Convert x into a MultiIndex with the same number of elements.
// Assumes T is a container that `unpack` accepts (e.g. an Array or Sequence of
// integers) — TODO confirm against unpack's supported types.
template <typename T>
__host__ __device__ constexpr auto to_multi_index(const T& x)
{
    return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}

// Element-wise in-place addition: y[i] += x[i] for every coordinate i.
//
// X must provide a static Size() equal to NSize (checked at compile time) and
// operator[] for element reads. Returns a reference to y, following the usual
// compound-assignment convention, so no copy of the index is produced.
template <index_t NSize, typename X>
__host__ __device__ constexpr MultiIndex<NSize>& operator+=(MultiIndex<NSize>& y, const X& x)
{
    static_assert(X::Size() == NSize, "wrong! size not the same");
    // static_for: compile-time loop over i = 0 .. NSize-1.
    static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; });
    return y;
}

// Element-wise in-place subtraction: y[i] -= x[i] for every coordinate i.
//
// Same requirements on X as operator+=. Returns a reference to y.
template <index_t NSize, typename X>
__host__ __device__ constexpr MultiIndex<NSize>& operator-=(MultiIndex<NSize>& y, const X& x)
{
    static_assert(X::Size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; });
    return y;
}

// Element-wise sum: returns r with r[i] = a[i] + b[i].
// T must provide a static Size() equal to NSize and operator[].
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator+(const MultiIndex<NSize>& a, const T& b)
{
    using type = MultiIndex<NSize>;
    static_assert(T::Size() == NSize, "wrong! size not the same");
    type r;
    // Every element of r is written below, so it is not read uninitialised.
    static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; });
    return r;
}

// Element-wise difference: returns r with r[i] = a[i] - b[i].
// T must provide a static Size() equal to NSize and operator[].
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator-(const MultiIndex<NSize>& a, const T& b)
{
    using type = MultiIndex<NSize>;
    static_assert(T::Size() == NSize, "wrong! size not the same");
    type r;
    static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; });
    return r;
}

// Element-wise (Hadamard) product: returns r with r[i] = a[i] * b[i].
// Note: this is NOT scaling by a scalar; b must be a sized container whose
// static Size() equals NSize.
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator*(const MultiIndex<NSize>& a, const T& b)
{
    using type = MultiIndex<NSize>;
    static_assert(T::Size() == NSize, "wrong! size not the same");
    type r;
    static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; });
    return r;
}

} // namespace ck
#endif