#ifndef CK_DYNAMIC_BUFFER_HPP
#define CK_DYNAMIC_BUFFER_HPP

#include "amd_buffer_addressing_v2.hpp"

namespace ck {

template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
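
    // Get/Set transfer one vector X per call, where X must pack a whole number of T
    // scalars of the same scalar type (checked by the static_assert below). For
    // example, with T = half_t and X = half4_t, t_per_x is 4 and a single call moves
    // four scalars starting at offset i. The is_valid_offset flag masks out-of-range
    // accesses: an invalid Get yields a zero-initialized X and an invalid Set is
    // dropped, instead of touching memory. (half_t/half4_t are only illustrative
    // element types.)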

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cv_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cv_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X needs to be a multiple of T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            return amd_buffer_load_v2<remove_cv_t<T>, t_per_x>(
                p_data_, i, is_valid_offset, element_space_size_);
#else
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
#endif
        }
        else
        {
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
        }
    }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cv_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cv_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X needs to be a multiple of T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            amd_buffer_store_v2<remove_cv_t<T>, t_per_x>(
                x, p_data_, i, is_valid_offset, element_space_size_);
#else
            if(is_valid_offset)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
            }
#endif
        }
        else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
        {
            if(is_valid_offset)
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
                *reinterpret_cast<X*>(&p_data_[i]) = x;
#else
                // HACK: the compiler would lower IR "store<i8, 16> address_space(3)" into
                // inefficient ISA, so let the compiler emit IR "store<i32, 4>" instead,
                // which is lowered to ds_write_b128
                // TODO: remove this after compiler fix
                if constexpr(is_same<typename scalar_type<remove_cv_t<T>>::type, int8_t>::value)
                {
                    static_assert(
                        (is_same<remove_cv_t<T>, int8_t>::value &&
                         is_same<remove_cvref_t<X>, int8_t>::value) ||
                            (is_same<remove_cv_t<T>, int8_t>::value &&
                             is_same<remove_cvref_t<X>, int8x2_t>::value) ||
                            (is_same<remove_cv_t<T>, int8_t>::value &&
                             is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                            (is_same<remove_cv_t<T>, int8x4_t>::value &&
                             is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                            (is_same<remove_cv_t<T>, int8x8_t>::value &&
                             is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                            (is_same<remove_cv_t<T>, int8x16_t>::value &&
                             is_same<remove_cvref_t<X>, int8x16_t>::value),
                        "wrong! not implemented for this combination, please add "
                        "implementation");

                    if constexpr(is_same<remove_cv_t<T>, int8_t>::value &&
                                 is_same<remove_cvref_t<X>, int8_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int8_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int8_t*>(&x);
                    }
                    else if constexpr(is_same<remove_cv_t<T>, int8_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x2_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int16_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int16_t*>(&x);
                    }
                    else if constexpr(is_same<remove_cv_t<T>, int8_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x4_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32_t*>(&x);
                    }
                    else if constexpr(is_same<remove_cv_t<T>, int8x4_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x4_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32_t*>(&x);
                    }
                    else if constexpr(is_same<remove_cv_t<T>, int8x8_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x8_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32x2_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32x2_t*>(&x);
                    }
                    else if constexpr(is_same<remove_cv_t<T>, int8x16_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x16_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32x4_t*>(&x);
                    }
                }
                else
                {
                    *reinterpret_cast<X*>(&p_data_[i]) = x;
                }
#endif
            }
        }
        else
        {
            if(is_valid_offset)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
            }
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
}

} // namespace ck
#endif
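
// -----------------------------------------------------------------------------
// Usage sketch (illustrative only, kept as a comment so the header itself is
// unchanged): how a kernel might wrap raw global pointers with
// make_dynamic_buffer and move data through the vectorized Get/Set interface.
// The kernel name, the grid-to-offset mapping, and the float4_t grouping are
// assumptions made for illustration, not part of this header.
//
//   __global__ void copy_float4_example(const float* p_src, float* p_dst, ck::index_t n)
//   {
//       // address space is carried as a template argument; n bounds buffer addressing
//       auto src_buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(p_src, n);
//       auto dst_buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(p_dst, n);
//
//       const ck::index_t i = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
//       const bool is_valid = (i + 4) <= n;
//
//       // one Get/Set pair moves four floats; invalid lanes read back as zero
//       // and are dropped on store instead of faulting
//       auto v = src_buf.Get<ck::float4_t>(i, is_valid);
//       dst_buf.Set<ck::float4_t>(i, is_valid, v);
//   }
// -----------------------------------------------------------------------------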