"...git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "d49e6e454eca30912b06f0af7e2d5357afe430ec"
Commit a91b68df authored by Chao Liu
Browse files

DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element

parent 2cbabbba
......@@ -133,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static_assert(WPerThread % WoPerThreadSubC == 0, "");
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
......
......@@ -227,7 +227,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
c_thread_buf;
// initialize output thread tensor
......@@ -251,7 +252,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// double regsiter buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data
......
......@@ -402,7 +402,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize()>
c_mr_nr_blk_desc.GetElementSpaceSize(),
true>
c_thread_buf;
// LDS allocation for A and B: be careful of alignment
......@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Number<M2>{},
Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
......
......@@ -1242,7 +1242,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
SrcCoord src_coord_;
DstCoord dst_coord_;
......
......@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
SrcCoord src_coord_;
DstCoord dst_coord_;
......
......@@ -10,25 +10,25 @@ union BufferResource
{
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data;
int32x4_t content;
StaticallyIndexedArray<T*, 2> address;
StaticallyIndexedArray<int32_t, 4> range;
StaticallyIndexedArray<int32_t, 4> config;
};
template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
{
BufferResource<T> wave_buffer_resource;
// wavewise base address (64 bit)
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
// wavewise range (32 bit)
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
// wavewise setting (32 bit)
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
return wave_buffer_resource.data;
return wave_buffer_resource.content;
}
// load
......@@ -204,8 +204,7 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
......@@ -412,7 +411,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
}
template <typename T, index_t N>
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
......@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave to be a wavewise pointer.
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_element_space)
amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space);
make_wave_buffer_resource(p_src_wave, src_element_space_size);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
return amd_buffer_load_impl_v2<scalar_t, vector_size>(
return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_data_valid ? tmp : vector_t(0);
return src_thread_element_valid ? tmp : vector_t(0);
#endif
}
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size,
T customized_value)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(customized_value);
}
// buffer_store requires:
// 1) p_dst_wave must be global memory
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_element_space)
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space);
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
amd_buffer_store_impl_v2<scalar_t, vector_size>(
amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_data_valid)
if(dst_thread_element_valid)
{
amd_buffer_store_impl_v2<scalar_t, vector_size>(
amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
......
......@@ -6,34 +6,43 @@
namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
typename ElementSpaceSize,
bool InvalidElementUseNumericalZeroValue>
struct DynamicBuffer
{
using type = T;
T* p_data_;
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
}
__host__ __device__ constexpr DynamicBuffer(T* p_data,
ElementSpaceSize element_space_size,
T invalid_element_value)
: p_data_{p_data},
element_space_size_{element_space_size},
invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{
return BufferAddressSpace;
}
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector =
......@@ -45,20 +54,41 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T");
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
p_data_, i, is_valid_offset, element_space_size_);
bool constexpr use_amd_buffer_addressing = true;
#else
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(InvalidElementUseNumericalZeroValue)
{
return amd_buffer_load_invalid_element_return_return_zero<
remove_cv_t<remove_reference_t<T>>,
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
}
else
{
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
return amd_buffer_load_invalid_element_return_customized_value<
remove_cv_t<remove_reference_t<T>>,
t_per_x>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
}
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
}
else
{
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
: X{invalid_element_value_};
}
}
}
......@@ -67,7 +97,7 @@ struct DynamicBuffer
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector =
......@@ -84,10 +114,10 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
x, p_data_, i, is_valid_offset, element_space_size_);
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
#else
if(is_valid_offset)
if(is_valid_element)
{
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
}
......@@ -95,7 +125,7 @@ struct DynamicBuffer
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
{
if(is_valid_offset)
if(is_valid_element)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
......@@ -185,7 +215,7 @@ struct DynamicBuffer
}
else
{
if(is_valid_offset)
if(is_valid_element)
{
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
}
......@@ -197,12 +227,18 @@ struct DynamicBuffer
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
typename T,
typename ElementSpaceSize>
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
p, element_space_size, invalid_element_value};
}
} // namespace ck
......
......@@ -5,30 +5,66 @@
namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
index_t N,
bool InvalidElementUseNumericalZeroValue>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
T invalid_element_value_ = T{0};
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
: base{}, invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{
return BufferAddressSpace;
}
template <index_t I>
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? At(i) : T{0};
}
else
{
return is_valid_element ? At(i) : invalid_element_value_;
}
}
template <index_t I>
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
{
if(is_valid_element)
{
At(i) = x;
}
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
typename T,
index_t N>
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<BufferAddressSpace, T, N>{};
return StaticBuffer<BufferAddressSpace, T, N, true>{};
}
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
{
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
}
} // namespace ck
......
......@@ -21,7 +21,7 @@
#define USE_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment