#ifndef CK_BUFFER_HPP #define CK_BUFFER_HPP #include "amd_buffer_addressing.hpp" #include "c_style_pointer_cast.hpp" #include "enable_if.hpp" namespace ck { template struct DynamicBuffer { using type = T; T* p_data_; ElementSpaceSize element_space_size_; T invalid_element_value_ = T{0}; __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) : p_data_{p_data}, element_space_size_{element_space_size} { } __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size, T invalid_element_value) : p_data_{p_data}, element_space_size_{element_space_size}, invalid_element_value_{invalid_element_value} { } __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() { return BufferAddressSpace; } __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } template >>::type, typename scalar_type>>::type>::value, bool>::type = false> __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>>::vector_size; constexpr index_t scalar_per_x_vector = scalar_type>>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X need to be multiple T"); #if CK_USE_AMD_BUFFER_ADDRESSING bool constexpr use_amd_buffer_addressing = true; #else bool constexpr use_amd_buffer_addressing = false; #endif if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing) { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; if constexpr(InvalidElementUseNumericalZeroValue) { return amd_buffer_load_invalid_element_return_return_zero< remove_cv_t>, t_per_x>(p_data_, i, is_valid_element, element_space_size_); } else { return amd_buffer_load_invalid_element_return_customized_value< remove_cv_t>, t_per_x>( p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); } } else { if constexpr(InvalidElementUseNumericalZeroValue) { return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) : X{0}; } else { return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) : X{invalid_element_value_}; } } } template >>::type, typename scalar_type>>::type>::value, bool>::type = false> __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>>::vector_size; constexpr index_t scalar_per_x_vector = scalar_type>>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X need to be multiple T"); if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) { #if CK_USE_AMD_BUFFER_ADDRESSING constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_store>, t_per_x>( x, p_data_, i, is_valid_element, element_space_size_); #else if(is_valid_element) { *c_style_pointer_cast(&p_data_[i]) = x; } #endif } else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) { if(is_valid_element) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE *c_style_pointer_cast(&p_data_[i]) = x; #else // HACK: compiler would lower IR "store address_space(3)" into // inefficient // ISA, so I try to let compiler emit IR "store" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix if constexpr(is_same>>::type, int8_t>::value) { static_assert( (is_same>, int8_t>::value && is_same>, int8_t>::value) || (is_same>, int8_t>::value && is_same>, int8x2_t>::value) || (is_same>, int8_t>::value && is_same>, int8x4_t>::value) || (is_same>, int8x4_t>::value && is_same>, int8x4_t>::value) || (is_same>, int8x8_t>::value && is_same>, int8x8_t>::value) || (is_same>, int8x16_t>::value && is_same>, int8x16_t>::value), "wrong! not implemented for this combination, please add " "implementation"); if constexpr(is_same>, int8_t>::value && is_same>, int8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8_t>::value && is_same>, int8x2_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8_t>::value && is_same>, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8x4_t>::value && is_same>, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8x8_t>::value && is_same>, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8x16_t>::value && is_same>, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } } else { *c_style_pointer_cast(&p_data_[i]) = x; } #endif } } else { if(is_valid_element) { *c_style_pointer_cast(&p_data_[i]) = x; } } } __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } }; template __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) { return DynamicBuffer{p, element_space_size}; } template < AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize, typename X, typename enable_if, remove_cvref_t>::value, bool>::type = false> __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) { return DynamicBuffer{ p, element_space_size, invalid_element_value}; } } // namespace ck #endif