Commit a91b68df authored by Chao Liu's avatar Chao Liu
Browse files

DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element

parent 2cbabbba
...@@ -133,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -133,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static_assert(WPerThread % WoPerThreadSubC == 0, ""); static_assert(WPerThread % WoPerThreadSubC == 0, "");
// thread A buffer for GEMM // thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf; a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA, constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
......
...@@ -227,7 +227,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -227,7 +227,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc, FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
c_thread_buf; c_thread_buf;
// initialize output thread tensor // initialize output thread tensor
...@@ -251,7 +252,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -251,7 +252,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// double regsiter buffer for b // double regsiter buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB, FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data // LDS double buffer: preload data
......
...@@ -402,7 +402,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -402,7 +402,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>, vector_type<FloatAcc, BlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize()> c_mr_nr_blk_desc.GetElementSpaceSize(),
true>
c_thread_buf; c_thread_buf;
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Number<M2>{}, Number<M2>{},
Number<1>{})); Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()> StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_; c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) { static_for<0, MRepeat, 1>{}([&](auto mr_i) {
......
...@@ -1242,7 +1242,7 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -1242,7 +1242,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_; StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
......
...@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_; StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
......
...@@ -10,25 +10,25 @@ union BufferResource ...@@ -10,25 +10,25 @@ union BufferResource
{ {
// 128 bit SGPRs to supply buffer resource in buffer instructions // 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data; int32x4_t content;
StaticallyIndexedArray<T*, 2> address; StaticallyIndexedArray<T*, 2> address;
StaticallyIndexedArray<int32_t, 4> range; StaticallyIndexedArray<int32_t, 4> range;
StaticallyIndexedArray<int32_t, 4> config; StaticallyIndexedArray<int32_t, 4> config;
}; };
template <typename T> template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size) __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
{ {
BufferResource<T> wave_buffer_resource; BufferResource<T> wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave); wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T); wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
// wavewise setting (32 bit) // wavewise setting (32 bit)
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
return wave_buffer_resource.data; return wave_buffer_resource.content;
} }
// load // load
...@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, ...@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type<T, N>::type __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset,
index_t src_thread_addr_offset, index_t src_wave_addr_offset)
index_t src_wave_addr_offset)
{ {
static_assert( static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
...@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
} }
template <typename T, index_t N> template <typename T, index_t N>
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data, __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource, int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset, index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset) index_t dst_wave_addr_offset)
{ {
static_assert( static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
...@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
// buffer_load requires: // buffer_load requires:
// 1) p_src_wave must be in global memory space // 1) p_src_wave must be in global memory space
// 2) p_src_wave to be a wavewise pointer. // 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type __device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_v2(const T* p_src_wave, amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
index_t src_thread_data_offset, index_t src_thread_element_offset,
bool src_thread_data_valid, bool src_thread_element_valid,
index_t src_element_space) index_t src_element_space_size)
{ {
const int32x4_t src_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space); make_wave_buffer_resource(p_src_wave, src_element_space_size);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
return amd_buffer_load_impl_v2<scalar_t, vector_size>( return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else #else
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_data_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#endif #endif
} }
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size,
T customized_value)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(customized_value);
}
// buffer_store requires: // buffer_store requires:
// 1) p_dst_wave must be global memory // 1) p_dst_wave must be global memory
// 2) p_dst_wave to be a wavewise pointer. // 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
__device__ void __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data, T* p_dst_wave,
T* p_dst_wave, const index_t dst_thread_element_offset,
const index_t dst_thread_data_offset, const bool dst_thread_element_valid,
const bool dst_thread_data_valid, const index_t dst_element_space_size)
const index_t dst_element_space)
{ {
const int32x4_t dst_wave_buffer_resource = const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space); make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T); index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type; using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type; using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
amd_buffer_store_impl_v2<scalar_t, vector_size>( amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else #else
if(dst_thread_data_valid) if(dst_thread_element_valid)
{ {
amd_buffer_store_impl_v2<scalar_t, vector_size>( amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
} }
#endif #endif
......
...@@ -6,34 +6,43 @@ ...@@ -6,34 +6,43 @@
namespace ck { namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize> template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
typename ElementSpaceSize,
bool InvalidElementUseNumericalZeroValue>
struct DynamicBuffer struct DynamicBuffer
{ {
using type = T; using type = T;
T* p_data_; T* p_data_;
ElementSpaceSize element_space_size_; ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size} : p_data_{p_data}, element_space_size_{element_space_size}
{ {
} }
__host__ __device__ constexpr DynamicBuffer(T* p_data,
ElementSpaceSize element_space_size,
T invalid_element_value)
: p_data_{p_data},
element_space_size_{element_space_size},
invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{ {
return BufferAddressSpace; return BufferAddressSpace;
} }
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X, template <typename X,
typename std::enable_if< typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector =
...@@ -45,20 +54,41 @@ struct DynamicBuffer ...@@ -45,20 +54,41 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; bool constexpr use_amd_buffer_addressing = true;
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
p_data_, i, is_valid_offset, element_space_size_);
#else #else
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0}; bool constexpr use_amd_buffer_addressing = false;
#endif #endif
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(InvalidElementUseNumericalZeroValue)
{
return amd_buffer_load_invalid_element_return_return_zero<
remove_cv_t<remove_reference_t<T>>,
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
}
else
{
return amd_buffer_load_invalid_element_return_customized_value<
remove_cv_t<remove_reference_t<T>>,
t_per_x>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
}
} }
else else
{ {
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0}; if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
}
else
{
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
: X{invalid_element_value_};
}
} }
} }
...@@ -67,7 +97,7 @@ struct DynamicBuffer ...@@ -67,7 +97,7 @@ struct DynamicBuffer
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x) __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector =
...@@ -84,10 +114,10 @@ struct DynamicBuffer ...@@ -84,10 +114,10 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>( amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
x, p_data_, i, is_valid_offset, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_);
#else #else
if(is_valid_offset) if(is_valid_element)
{ {
*c_style_pointer_cast<X*>(&p_data_[i]) = x; *c_style_pointer_cast<X*>(&p_data_[i]) = x;
} }
...@@ -95,7 +125,7 @@ struct DynamicBuffer ...@@ -95,7 +125,7 @@ struct DynamicBuffer
} }
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
{ {
if(is_valid_offset) if(is_valid_element)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
*c_style_pointer_cast<X*>(&p_data_[i]) = x; *c_style_pointer_cast<X*>(&p_data_[i]) = x;
...@@ -185,7 +215,7 @@ struct DynamicBuffer ...@@ -185,7 +215,7 @@ struct DynamicBuffer
} }
else else
{ {
if(is_valid_offset) if(is_valid_element)
{ {
*c_style_pointer_cast<X*>(&p_data_[i]) = x; *c_style_pointer_cast<X*>(&p_data_[i]) = x;
} }
...@@ -197,12 +227,18 @@ struct DynamicBuffer ...@@ -197,12 +227,18 @@ struct DynamicBuffer
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
}; };
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic, template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
typename T,
typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{ {
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size}; return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
p, element_space_size, invalid_element_value};
} }
} // namespace ck } // namespace ck
......
...@@ -5,30 +5,66 @@ ...@@ -5,30 +5,66 @@
namespace ck { namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N> template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
index_t N,
bool InvalidElementUseNumericalZeroValue>
struct StaticBuffer : public StaticallyIndexedArray<T, N> struct StaticBuffer : public StaticallyIndexedArray<T, N>
{ {
using type = T; using type = T;
using base = StaticallyIndexedArray<T, N>; using base = StaticallyIndexedArray<T, N>;
T invalid_element_value_ = T{0};
__host__ __device__ constexpr StaticBuffer() : base{} {} __host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
: base{}, invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{ {
return BufferAddressSpace; return BufferAddressSpace;
} }
template <index_t I>
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? At(i) : T{0};
}
else
{
return is_valid_element ? At(i) : invalid_element_value_;
}
}
template <index_t I>
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
{
if(is_valid_element)
{
At(i) = x;
}
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
}; };
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic, template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
typename T,
index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>) __host__ __device__ constexpr auto make_static_buffer(Number<N>)
{ {
return StaticBuffer<BufferAddressSpace, T, N>{}; return StaticBuffer<BufferAddressSpace, T, N, true>{};
}
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
{
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
} }
} // namespace ck } // namespace ck
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#define USE_MODE 1 #define USE_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 1 #define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment