// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "statically_indexed_array.hpp" namespace ck { // static buffer for scalar template // TODO remove this bool, no longer needed struct StaticBuffer : public StaticallyIndexedArray, N> { using S = remove_cvref_t; using type = S; using base = StaticallyIndexedArray; __host__ __device__ constexpr StaticBuffer() : base{} {} template __host__ __device__ constexpr StaticBuffer& operator=(const Tuple& y) { static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same"); StaticBuffer& x = *this; static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; }); return x; } __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } __host__ __device__ static constexpr index_t Size() { return N; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } // read access to scalar template __host__ __device__ constexpr const S& operator[](Number i) const { return base::operator[](i); } // write access to scalar template __host__ __device__ constexpr S& operator()(Number i) { return base::operator()(i); } // Get a vector (type X) // "is" is offset of S, not X. // "is" should be aligned to X template ::value, bool>::type = false> __host__ __device__ constexpr remove_reference_t GetAsType(Number is) const { using X = remove_cvref_t; constexpr index_t kSPerX = scalar_type::vector_size; static_assert(Is % kSPerX == 0, "wrong! \"Is\" should be aligned to X"); vector_type vx; static_for<0, kSPerX, 1>{}( [&](auto j) { vx.template AsType()(j) = base::operator[](is + j); }); return vx.template AsType().template At<0>(); } // Set a vector (type X) // "is" is offset of S, not X. // "is" should be aligned to X template ::value, bool>::type = false> __host__ __device__ constexpr void SetAsType(Number is, X_ x) { using X = remove_cvref_t; constexpr index_t kSPerX = scalar_type::vector_size; static_assert(Is % kSPerX == 0, "wrong! \"Is\" should be aligned to X"); const vector_type vx{x}; static_for<0, kSPerX, 1>{}( [&](auto j) { base::operator()(is + j) = vx.template AsType()[j]; }); } __host__ __device__ void Initialize(const S& x) { static_for<0, N, 1>{}([&](auto i) { operator()(i) = S{x}; }); } // FIXME: deprecated __host__ __device__ void Clear() { Initialize(0); } // FIXME: deprecated __host__ __device__ constexpr StaticBuffer& operator=(const S& v) { Initialize(v); return *this; } }; // static buffer for vector template ::value, bool>::type = false> struct StaticBufferTupleOfVector : public StaticallyIndexedArray, NumOfVector> { using V = typename vector_type::type; using base = StaticallyIndexedArray, NumOfVector>; static constexpr auto s_per_v = Number{}; static constexpr auto num_of_v_ = Number{}; static constexpr auto s_per_buf = s_per_v * num_of_v_; __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {} __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr index_t Size() { return s_per_buf; }; // Get S // i is offset of S template __host__ __device__ constexpr const S& operator[](Number i) const { constexpr auto i_v = i / s_per_v; constexpr auto i_s = i % s_per_v; return base::operator[](i_v).template AsType()[i_s]; } // Set S // i is offset of S template __host__ __device__ constexpr S& operator()(Number i) { constexpr auto i_v = i / s_per_v; constexpr auto i_s = i % s_per_v; return base::operator()(i_v).template AsType()(i_s); } // Get X // i is offset of S, not X. i should be aligned to X template ::value, bool>::type = false> __host__ __device__ constexpr auto GetAsType(Number i) const { constexpr auto s_per_x = Number>::vector_size>{}; static_assert(s_per_v % s_per_x == 0, "wrong! V must one or multiple X"); static_assert(i % s_per_x == 0, "wrong!"); constexpr auto i_v = i / s_per_v; constexpr auto i_x = (i % s_per_v) / s_per_x; return base::operator[](i_v).template AsType()[i_x]; } // Set X // i is offset of S, not X. i should be aligned to X template ::value, bool>::type = false> __host__ __device__ constexpr void SetAsType(Number i, X x) { constexpr auto s_per_x = Number>::vector_size>{}; static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X"); static_assert(i % s_per_x == 0, "wrong!"); constexpr auto i_v = i / s_per_v; constexpr auto i_x = (i % s_per_v) / s_per_x; base::operator()(i_v).template AsType()(i_x) = x; } // Get read access to vector_type V // i is offset of S, not V. i should be aligned to V template __host__ __device__ constexpr const auto& GetVectorTypeReference(Number i) const { static_assert(i % s_per_v == 0, "wrong!"); constexpr auto i_v = i / s_per_v; return base::operator[](i_v); } // Get write access to vector_type V // i is offset of S, not V. i should be aligned to V template __host__ __device__ constexpr auto& GetVectorTypeReference(Number i) { static_assert(i % s_per_v == 0, "wrong!"); constexpr auto i_v = i / s_per_v; return base::operator()(i_v); } __host__ __device__ void Clear() { constexpr index_t NumScalars = NumOfVector * ScalarPerVector; static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); }); } }; template __host__ __device__ constexpr auto make_static_buffer(Number) { return StaticBuffer{}; } template __host__ __device__ constexpr auto make_static_buffer(LongNumber) { return StaticBuffer{}; } } // namespace ck