#pragma once #include "amd_buffer_addressing.hpp" #include "c_style_pointer_cast.hpp" #include "config.hpp" #include "enable_if.hpp" namespace ck { template struct DynamicBuffer { using type = T; T* p_data_; ElementSpaceSize element_space_size_; T invalid_element_value_ = T{0}; __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) : p_data_{p_data}, element_space_size_{element_space_size} { } __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size, T invalid_element_value) : p_data_{p_data}, element_space_size_{element_space_size}, invalid_element_value_{invalid_element_value} { } __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return BufferAddressSpace; } __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } template >::type, typename scalar_type>::type>::value, bool>::type = false> __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); #if CK_USE_AMD_BUFFER_LOAD bool constexpr use_amd_buffer_addressing = true; #else bool constexpr use_amd_buffer_addressing = false; #endif if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; if constexpr(InvalidElementUseNumericalZeroValue) { return amd_buffer_load_invalid_element_return_zero, t_per_x>( p_data_, i, is_valid_element, element_space_size_); } else { return amd_buffer_load_invalid_element_return_customized_value, t_per_x>( p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); } } else { if(is_valid_element) { #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS X tmp; __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); return tmp; #else return *c_style_pointer_cast(&p_data_[i]); #endif } else { if constexpr(InvalidElementUseNumericalZeroValue) { return X{0}; } else { return X{invalid_element_value_}; } } } } template >::type, typename scalar_type>::type>::value, bool>::type = false> __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x) { if constexpr(Op == InMemoryDataOperationEnum::Set) { this->template Set(i, is_valid_element, x); } else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd) { this->template AtomicAdd(i, is_valid_element, x); } else if constexpr(Op == InMemoryDataOperationEnum::Add) { auto tmp = this->template Get(i, is_valid_element); this->template Set(i, is_valid_element, x + tmp); // tmp += x; // this->template Set(i, is_valid_element, tmp); } } template >::type, typename scalar_type>::type>::value, bool>::type = false> __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
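    // Usage sketch for Get/Update (hypothetical, not part of this header;
    // assumes CK vector types such as float4_t and a pointer p_global into a
    // global-memory allocation of element_space_size floats):
    //
    //   auto buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_global, element_space_size);
    //
    //   // vectorized read of 4 floats; invalid reads return zero here, since
    //   // the two-argument factory selects InvalidElementUseNumericalZeroValue = true
    //   float4_t v = buf.template Get<float4_t>(i, is_valid);
    //
    //   // accumulate via the Update dispatcher (a Get + Set round trip)
    //   buf.template Update<InMemoryDataOperationEnum::Add, float4_t>(i, is_valid, v);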
X should contain multiple T"); #if CK_USE_AMD_BUFFER_STORE bool constexpr use_amd_buffer_addressing = true; #else bool constexpr use_amd_buffer_addressing = false; #endif #if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE bool constexpr workaround_int8_ds_write_issue = true; #else bool constexpr workaround_int8_ds_write_issue = false; #endif if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_store, t_per_x>( x, p_data_, i, is_valid_element, element_space_size_); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && is_same>::type, int8_t>::value && workaround_int8_ds_write_issue) { if(is_valid_element) { // HACK: compiler would lower IR "store address_space(3)" into inefficient // ISA, so I try to let compiler emit IR "store" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix static_assert((is_same, int8_t>::value && is_same, int8_t>::value) || (is_same, int8_t>::value && is_same, int8x2_t>::value) || (is_same, int8_t>::value && is_same, int8x4_t>::value) || (is_same, int8_t>::value && is_same, int8x8_t>::value) || (is_same, int8_t>::value && is_same, int8x16_t>::value) || (is_same, int8x4_t>::value && is_same, int8x4_t>::value) || (is_same, int8x8_t>::value && is_same, int8x8_t>::value) || (is_same, int8x16_t>::value && is_same, int8x16_t>::value), "wrong! not implemented for this combination, please add " "implementation"); if constexpr(is_same, int8_t>::value && is_same, int8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same, int8_t>::value && is_same, int8x2_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same, int8_t>::value && is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same, int8_t>::value && is_same, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same, int8_t>::value && is_same, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same, int8x4_t>::value && is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same, int8x8_t>::value && is_same, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } else if constexpr(is_same, int8x16_t>::value && is_same, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } } } else { if(is_valid_element) { #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS X tmp = x; __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); #else *c_style_pointer_cast(&p_data_[i]) = x; #endif } } } template >::type, typename scalar_type>::type>::value, bool>::type = false> __host__ __device__ void AtomicAdd(index_t i, 
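    // Note on AtomicAdd below: it supports global memory only, and whether it
    // routes through AMD buffer atomics depends on the scalar type and the
    // CK_USE_AMD_BUFFER_ATOMIC_ADD_{INTEGER,FLOAT} config macros: int32_t,
    // float, and half_t (for even vector sizes) qualify when the matching
    // macro is enabled; everything else falls back to HIP atomicAdd. A
    // hypothetical call-site sketch, assuming T = float and a float2_t
    // partial result:
    //
    //   buf.template AtomicAdd<float2_t>(i, is_valid, partial_sum);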
bool is_valid_element, const X& x) { using scalar_t = typename scalar_type>::type; // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); #if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = is_same_v, int32_t> || is_same_v, float> || (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); #elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; #elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = is_same_v, float> || (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); #else bool constexpr use_amd_buffer_addressing = false; #endif if constexpr(use_amd_buffer_addressing) { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_atomic_add, t_per_x>( x, p_data_, i, is_valid_element, element_space_size_); } else { if(is_valid_element) { // FIXME: atomicAdd is defined by HIP, need to avoid implicit type casting when // calling it atomicAdd(c_style_pointer_cast(&p_data_[i]), x); } } } __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } }; template __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) { return DynamicBuffer{p, element_space_size}; } template < AddressSpaceEnum BufferAddressSpace, typename T, typename ElementSpaceSize, typename X, typename enable_if, remove_cvref_t>::value, bool>::type = false> __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) { return DynamicBuffer{ p, element_space_size, invalid_element_value}; } } // namespace ck
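// Usage sketch for the factories (hypothetical, not part of this header): the
// three-argument overload makes invalid elements read back as a custom padding
// value instead of zero, e.g. for padded reductions:
//
//   auto buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
//       p_data, element_space_size, /*invalid_element_value=*/-1.f);
//
//   float v = buf.template Get<float>(i, is_valid); // invalid -> -1.f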