Commit e371df51 authored by Chao Liu

use buffer load OOB check for padding

parent 7a929377
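The diff below threads an explicit `src_valid`/`dst_valid` flag from the tensor coordinate check (`IsOffsetValidAssumingUpperIndexIsValid`) down to the AMD buffer intrinsics. Instead of branching around each copy, an invalid access keeps executing but its byte offset is replaced with `-1`, which falls outside the buffer descriptor's range, so the hardware returns 0 for loads and drops stores and atomics; that is exactly the behavior wanted for padding. A minimal sketch of the trick (the descriptor helper name is made up; the real code fills a `BufferAddressConfig` as shown in the diff):

```cpp
// Sketch only: illustrates the out-of-bounds-offset trick this commit relies on.
// make_buffer_resource() is a hypothetical stand-in for the BufferAddressConfig
// setup performed in the real code below.
__device__ float load_or_zero(const float* p_src_block, int thread_offset_bytes, bool is_valid)
{
    auto resource = make_buffer_resource(p_src_block);

    // An offset of -1 lies outside the descriptor's range, so the buffer load
    // returns 0 instead of reading out of bounds; no divergent branch is needed.
    return __llvm_amdgcn_buffer_load_f32(resource,
                                         0,                                   // vindex
                                         is_valid ? thread_offset_bytes : -1, // OOB when invalid
                                         false,                               // glc
                                         false);                              // slc
}
```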
@@ -112,17 +112,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
             // Check src data's valid mapping situation, only check the first data in this src
             // vector. It's user's responsibility to make sure all data in the src vector
             // has the valid/invalid mapping situation
-            if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
-            {
-                transfer_data<SrcData,
-                              SrcDataPerRead,
-                              SrcAddressSpace,
-                              AddressSpace::Vgpr,
-                              InMemoryDataOperation::Set,
-                              SrcDataStride,
-                              1>(
-                    p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
-            }
+            transfer_data<SrcData,
+                          SrcDataPerRead,
+                          SrcAddressSpace,
+                          AddressSpace::Vgpr,
+                          InMemoryDataOperation::Set,
+                          SrcDataStride,
+                          1>(p_src,
+                             src_coord.GetOffset(),
+                             src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
+                             p_src_long_vector,
+                             buffer_offset,
+                             true);
         }

         // SrcData to DstData conversion
@@ -146,17 +147,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
             // Check dst data's valid mapping situation, only check the first data in this dst
             // vector. It's user's responsibility to make sure all data in the dst vector
             // has the valid/invalid mapping situation
-            if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
-            {
-                transfer_data<DstData,
-                              DstDataPerWrite,
-                              AddressSpace::Vgpr,
-                              DstAddressSpace,
-                              DstInMemOp,
-                              1,
-                              DstDataStride>(
-                    p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
-            }
+            transfer_data<DstData,
+                          DstDataPerWrite,
+                          AddressSpace::Vgpr,
+                          DstAddressSpace,
+                          DstInMemOp,
+                          1,
+                          DstDataStride>(p_dst_long_vector,
+                                         buffer_offset,
+                                         true,
+                                         p_dst,
+                                         dst_coord.GetOffset(),
+                                         dst_coord.IsOffsetValidAssumingUpperIndexIsValid());
         }
     });
 }
@@ -266,18 +268,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // src
                 // vector. It's user's responsibility to make sure all data in the src vector
                 // has the valid/invalid mapping situation
-                if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
-                {
-                    transfer_data<SrcData,
-                                  SrcDataPerRead,
-                                  SrcAddressSpace,
-                                  AddressSpace::Vgpr,
-                                  InMemoryDataOperation::Set>(p_src,
-                                                              src_nonlinear_coord.GetOffset() +
-                                                                  src_linear_offset,
-                                                              p_src_long_vector,
-                                                              buffer_offset);
-                }
+                transfer_data<SrcData,
+                              SrcDataPerRead,
+                              SrcAddressSpace,
+                              AddressSpace::Vgpr,
+                              InMemoryDataOperation::Set>(
+                    p_src,
+                    src_nonlinear_coord.GetOffset() + src_linear_offset,
+                    src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
+                    p_src_long_vector,
+                    buffer_offset,
+                    true);
             }

             // SrcData to DstData conversion
@@ -305,15 +306,16 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // dst
                 // vector. It's user's responsibility to make sure all data in the dst vector
                 // has the valid/invalid mapping situation
-                if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
-                {
-                    transfer_data<DstData,
-                                  DstDataPerWrite,
-                                  AddressSpace::Vgpr,
-                                  DstAddressSpace,
-                                  DstInMemOp>(
-                        p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
-                }
+                transfer_data<DstData,
+                              DstDataPerWrite,
+                              AddressSpace::Vgpr,
+                              DstAddressSpace,
+                              DstInMemOp>(p_dst_long_vector,
+                                          buffer_offset,
+                                          true,
+                                          p_dst,
+                                          dst_coord.GetOffset(),
+                                          dst_coord.IsOffsetValidAssumingUpperIndexIsValid());
             }
         });
     });
@@ -405,15 +407,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // src
                 // vector. It's user's responsibility to make sure all data in the src vector
                 // has the valid/invalid mapping situation
-                if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
-                {
-                    transfer_data<SrcData,
-                                  SrcDataPerRead,
-                                  SrcAddressSpace,
-                                  AddressSpace::Vgpr,
-                                  InMemoryDataOperation::Set>(
-                        p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
-                }
+                transfer_data<SrcData,
+                              SrcDataPerRead,
+                              SrcAddressSpace,
+                              AddressSpace::Vgpr,
+                              InMemoryDataOperation::Set>(
+                    p_src,
+                    src_coord.GetOffset(),
+                    src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
+                    p_src_long_vector,
+                    buffer_offset,
+                    true);
             }

             // SrcData to DstData conversion
@@ -450,18 +454,16 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // dst
                 // vector. It's user's responsibility to make sure all data in the dst vector
                 // has the valid/invalid mapping situation
-                if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
-                {
-                    transfer_data<DstData,
-                                  DstDataPerWrite,
-                                  AddressSpace::Vgpr,
-                                  DstAddressSpace,
-                                  DstInMemOp>(p_dst_long_vector,
-                                              buffer_offset,
-                                              p_dst,
-                                              dst_nonlinear_coord.GetOffset() +
-                                                  dst_linear_offset);
-                }
+                transfer_data<DstData,
+                              DstDataPerWrite,
+                              AddressSpace::Vgpr,
+                              DstAddressSpace,
+                              DstInMemOp>(p_dst_long_vector,
+                                          buffer_offset,
+                                          true,
+                                          p_dst,
+                                          dst_nonlinear_coord.GetOffset() + dst_linear_offset,
+                                          dst_coord.IsOffsetValidAssumingUpperIndexIsValid());
             }
         });
     });
@@ -150,8 +150,11 @@ __llvm_amdgcn_buffer_atomic_add_f32(float vdata,
 // 2) p_src to be a block-invariant pointer.
 // It is user's responsibility to make sure that is true.
 template <typename T, index_t VectorSize>
-__device__ typename vector_type<T, VectorSize>::MemoryType amd_buffer_load(
-    const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset);
+__device__ typename vector_type<T, VectorSize>::MemoryType
+amd_buffer_load(const T* p_src_block,
+                index_t src_thread_data_offset,
+                index_t src_const_data_offset,
+                bool src_valid);

 // buffer_store requires:
 // 1) p_src must be in vgpr space, d_dst must be global memory
@@ -161,18 +164,21 @@ template <typename T, index_t VectorSize>
 __device__ void amd_buffer_store(const T* p_src,
                                  T* p_dst_block,
                                  index_t dst_thread_data_offset,
-                                 index_t dst_const_data_offset);
+                                 index_t dst_const_data_offset,
+                                 bool dst_valid);

 template <typename T, index_t VectorSize>
 __device__ void amd_buffer_atomic_add(const T* p_src,
                                       T* p_dst_block,
                                       index_t dst_thread_data_offset,
-                                      index_t dst_const_data_offset);
+                                      index_t dst_const_data_offset,
+                                      bool dst_valid);

 template <>
 __device__ float amd_buffer_load<float, 1>(const float* p_src_block,
                                            index_t src_thread_data_offset,
-                                           index_t src_const_data_offset)
+                                           index_t src_const_data_offset,
+                                           bool src_valid)
 {
     BufferAddressConfig<float> src_block_config;
@@ -187,13 +193,18 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_block,
     index_t src_const_addr_offset = src_const_data_offset * sizeof(float);

-    return __llvm_amdgcn_buffer_load_f32(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_f32(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 }

 template <>
 __device__ float2_t amd_buffer_load<float, 2>(const float* p_src_block,
                                               index_t src_thread_data_offset,
-                                              index_t src_const_data_offset)
+                                              index_t src_const_data_offset,
+                                              bool src_valid)
 {
     BufferAddressConfig<float> src_block_config;
@@ -208,13 +219,18 @@ __device__ float2_t amd_buffer_load<float, 2>(const float* p_src_block,
     index_t src_const_addr_offset = src_const_data_offset * sizeof(float);

-    return __llvm_amdgcn_buffer_load_f32x2(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_f32x2(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 }

 template <>
 __device__ float4_t amd_buffer_load<float, 4>(const float* p_src_block,
                                               index_t src_thread_data_offset,
-                                              index_t src_const_data_offset)
+                                              index_t src_const_data_offset,
+                                              bool src_valid)
 {
     BufferAddressConfig<float> src_block_config;
@@ -229,13 +245,18 @@ __device__ float4_t amd_buffer_load<float, 4>(const float* p_src_block,
     index_t src_const_addr_offset = src_const_data_offset * sizeof(float);

-    return __llvm_amdgcn_buffer_load_f32x4(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_f32x4(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 }

 template <>
 __device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_block,
                                              index_t src_thread_data_offset,
-                                             index_t src_const_data_offset)
+                                             index_t src_const_data_offset,
+                                             bool src_valid)
 {
     BufferAddressConfig<half_t> src_block_config;
@@ -251,16 +272,21 @@ __device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_block,
     index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);

-    return __llvm_amdgcn_buffer_load_f16(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_f16(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 #else
-    return p_src_block[src_thread_data_offset + src_const_data_offset];
+    return src_valid ? p_src_block[src_thread_data_offset + src_const_data_offset] : 0;
 #endif
 }

 template <>
 __device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_block,
                                               index_t src_thread_data_offset,
-                                              index_t src_const_data_offset)
+                                              index_t src_const_data_offset,
+                                              bool src_valid)
 {
     BufferAddressConfig<half_t> src_block_config;
@@ -276,10 +302,18 @@ __device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_block,
 #if !CK_WORKAROUND_SWDEV_231101
-    return __llvm_amdgcn_buffer_load_f16x2(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_f16x2(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 #else
-    float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);

     return *reinterpret_cast<half2_t*>(&dst_out_tmp);
 #endif
@@ -288,7 +322,8 @@ __device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_block,
 template <>
 __device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_block,
                                               index_t src_thread_data_offset,
-                                              index_t src_const_data_offset)
+                                              index_t src_const_data_offset,
+                                              bool src_valid)
 {
     BufferAddressConfig<half_t> src_block_config;
@@ -304,10 +339,18 @@ __device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_block,
 #if !CK_WORKAROUND_SWDEV_231101
-    return __llvm_amdgcn_buffer_load_f16x4(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_f16x4(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 #else
-    float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);

     return *reinterpret_cast<half4_t*>(&dst_out_tmp);
 #endif
@@ -316,7 +359,8 @@ __device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_block,
 template <>
 __device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_block,
                                               index_t src_thread_data_offset,
-                                              index_t src_const_data_offset)
+                                              index_t src_const_data_offset,
+                                              bool src_valid)
 {
     BufferAddressConfig<half_t> src_block_config;
@@ -330,20 +374,21 @@ __device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_block,
     index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
     index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);

-#if !CK_WORKAROUND_SWDEV_231101
-    static_assert(false, "wrong! not supported");
-#else
-    float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);

     return *reinterpret_cast<half8_t*>(&dst_out_tmp);
-#endif
 }

 template <>
 __device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_block,
                                              index_t src_thread_data_offset,
-                                             index_t src_const_data_offset)
+                                             index_t src_const_data_offset,
+                                             bool src_valid)
 {
     BufferAddressConfig<ushort> src_block_config;
@@ -359,16 +404,21 @@ __device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_block,
     index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);

-    return __llvm_amdgcn_buffer_load_bf16(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_bf16(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 #else
-    return p_src_block[src_thread_data_offset + src_const_data_offset];
+    return src_valid ? p_src_block[src_thread_data_offset + src_const_data_offset] : 0;
 #endif
 }

 template <>
 __device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_block,
                                                 index_t src_thread_data_offset,
-                                                index_t src_const_data_offset)
+                                                index_t src_const_data_offset,
+                                                bool src_valid)
 {
     BufferAddressConfig<ushort> src_block_config;
@@ -384,10 +434,18 @@ __device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_block,
 #if !CK_WORKAROUND_SWDEV_231101
-    return __llvm_amdgcn_buffer_load_bf16x2(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_bf16x2(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 #else
-    float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);

     return *reinterpret_cast<ushort2_t*>(&dst_out_tmp);
 #endif
@@ -396,7 +454,8 @@ __device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_block,
 template <>
 __device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_block,
                                                 index_t src_thread_data_offset,
-                                                index_t src_const_data_offset)
+                                                index_t src_const_data_offset,
+                                                bool src_valid)
 {
     BufferAddressConfig<ushort> src_block_config;
@@ -412,10 +471,18 @@ __device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_block,
 #if !CK_WORKAROUND_SWDEV_231101
-    return __llvm_amdgcn_buffer_load_bf16x4(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    return __llvm_amdgcn_buffer_load_bf16x4(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);
 #else
-    float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);

     return *reinterpret_cast<ushort4_t*>(&dst_out_tmp);
 #endif
@@ -424,7 +491,8 @@ __device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_block,
 template <>
 __device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_block,
                                                 index_t src_thread_data_offset,
-                                                index_t src_const_data_offset)
+                                                index_t src_const_data_offset,
+                                                bool src_valid)
 {
     BufferAddressConfig<ushort> src_block_config;
@@ -438,21 +506,22 @@ __device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_block,
     index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
     index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);

-#if !CK_WORKAROUND_SWDEV_231101
-    static_assert(false, "wrong! not implemented");
-#else
-    float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
-        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
+    float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
+        src_block_config.data,
+        0,
+        src_valid ? (src_thread_addr_offset + src_const_addr_offset) : -1,
+        false,
+        false);

     return *reinterpret_cast<ushort8_t*>(&dst_out_tmp);
-#endif
 }

 template <>
 __device__ void amd_buffer_store<float, 1>(const float* p_src,
                                            float* p_dst_block,
                                            index_t dst_thread_data_offset,
-                                           index_t dst_const_data_offset)
+                                           index_t dst_const_data_offset,
+                                           bool dst_valid)
 {
     BufferAddressConfig<float> dst_block_config;
@@ -469,7 +538,8 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src,
     __llvm_amdgcn_buffer_store_f32(*p_src,
                                    dst_block_config.data,
                                    0,
-                                   dst_thread_addr_offset + dst_const_addr_offset,
+                                   dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                             : -1,
                                    false,
                                    false);
 }
@@ -478,7 +548,8 @@ template <>
 __device__ void amd_buffer_store<float, 2>(const float* p_src,
                                            float* p_dst_block,
                                            index_t dst_thread_data_offset,
-                                           index_t dst_const_data_offset)
+                                           index_t dst_const_data_offset,
+                                           bool dst_valid)
 {
     BufferAddressConfig<float> dst_block_config;
@@ -495,7 +566,8 @@ __device__ void amd_buffer_store<float, 2>(const float* p_src,
     __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src),
                                      dst_block_config.data,
                                      0,
-                                     dst_thread_addr_offset + dst_const_addr_offset,
+                                     dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                               : -1,
                                      false,
                                      false);
 }
@@ -504,7 +576,8 @@ template <>
 __device__ void amd_buffer_store<float, 4>(const float* p_src,
                                            float* p_dst_block,
                                            index_t dst_thread_data_offset,
-                                           index_t dst_const_data_offset)
+                                           index_t dst_const_data_offset,
+                                           bool dst_valid)
 {
     BufferAddressConfig<float> dst_block_config;
@@ -521,7 +594,8 @@ __device__ void amd_buffer_store<float, 4>(const float* p_src,
     __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src),
                                      dst_block_config.data,
                                      0,
-                                     dst_thread_addr_offset + dst_const_addr_offset,
+                                     dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                               : -1,
                                      false,
                                      false);
 }
@@ -530,7 +604,8 @@ template <>
 __device__ void amd_buffer_store<half_t, 1>(const half_t* p_src,
                                             half_t* p_dst_block,
                                             index_t dst_thread_data_offset,
-                                            index_t dst_const_data_offset)
+                                            index_t dst_const_data_offset,
+                                            bool dst_valid)
 {
     BufferAddressConfig<half_t> dst_block_config;
@@ -548,11 +623,15 @@ __device__ void amd_buffer_store<half_t, 1>(const half_t* p_src,
     __llvm_amdgcn_buffer_store_f16(*p_src,
                                    dst_block_config.data,
                                    0,
-                                   dst_thread_addr_offset + dst_const_addr_offset,
+                                   dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                             : -1,
                                    false,
                                    false);
 #else
-    p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src;
+    if(dst_valid)
+    {
+        p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src;
+    }
 #endif
 }
@@ -560,7 +639,8 @@ template <>
 __device__ void amd_buffer_store<half_t, 2>(const half_t* p_src,
                                             half_t* p_dst_block,
                                             index_t dst_thread_data_offset,
-                                            index_t dst_const_data_offset)
+                                            index_t dst_const_data_offset,
+                                            bool dst_valid)
 {
     BufferAddressConfig<half_t> dst_block_config;
@@ -578,7 +658,8 @@ __device__ void amd_buffer_store<half_t, 2>(const half_t* p_src,
     __llvm_amdgcn_buffer_store_f16x2(*reinterpret_cast<const half2_t*>(p_src),
                                      dst_block_config.data,
                                      0,
-                                     dst_thread_addr_offset + dst_const_addr_offset,
+                                     dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                               : -1,
                                      false,
                                      false);
 #else
@@ -587,7 +668,8 @@ __device__ void amd_buffer_store<half_t, 2>(const half_t* p_src,
     __llvm_amdgcn_buffer_store_f32(*p_src_tmp,
                                    dst_block_config.data,
                                    0,
-                                   dst_thread_addr_offset + dst_const_addr_offset,
+                                   dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                             : -1,
                                    false,
                                    false);
 #endif
@@ -597,7 +679,8 @@ template <>
 __device__ void amd_buffer_store<half_t, 4>(const half_t* p_src,
                                             half_t* p_dst_block,
                                             index_t dst_thread_data_offset,
-                                            index_t dst_const_data_offset)
+                                            index_t dst_const_data_offset,
+                                            bool dst_valid)
 {
     index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
     index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t);
@@ -615,7 +698,8 @@ __device__ void amd_buffer_store<half_t, 4>(const half_t* p_src,
     __llvm_amdgcn_buffer_store_f16x4(*reinterpret_cast<const half4_t*>(p_src),
                                      dst_block_config.data,
                                      0,
-                                     dst_thread_addr_offset + dst_const_addr_offset,
+                                     dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                               : -1,
                                      false,
                                      false);
 #else
@@ -624,7 +708,8 @@ __device__ void amd_buffer_store<half_t, 4>(const half_t* p_src,
     __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
                                      dst_block_config.data,
                                      0,
-                                     dst_thread_addr_offset + dst_const_addr_offset,
+                                     dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                               : -1,
                                      false,
                                      false);
 #endif
@@ -634,7 +719,8 @@ template <>
 __device__ void amd_buffer_store<ushort, 1>(const ushort* p_src,
                                             ushort* p_dst_block,
                                             index_t dst_thread_data_offset,
-                                            index_t dst_const_data_offset)
+                                            index_t dst_const_data_offset,
+                                            bool dst_valid)
 {
     BufferAddressConfig<ushort> dst_block_config;
@@ -652,11 +738,15 @@ __device__ void amd_buffer_store<ushort, 1>(const ushort* p_src,
     __llvm_amdgcn_buffer_store_bf16(*p_src,
                                     dst_block_config.data,
                                     0,
-                                    dst_thread_addr_offset + dst_const_addr_offset,
+                                    dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                              : -1,
                                     false,
                                     false);
 #else
-    p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src;
+    if(dst_valid)
+    {
+        p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src;
+    }
 #endif
 }
@@ -664,7 +754,8 @@ template <>
 __device__ void amd_buffer_store<ushort, 2>(const ushort* p_src,
                                             ushort* p_dst_block,
                                             index_t dst_thread_data_offset,
-                                            index_t dst_const_data_offset)
+                                            index_t dst_const_data_offset,
+                                            bool dst_valid)
 {
     BufferAddressConfig<ushort> dst_block_config;
@@ -682,7 +773,8 @@ __device__ void amd_buffer_store<ushort, 2>(const ushort* p_src,
     __llvm_amdgcn_buffer_store_bf16x2(*p_src,
                                       dst_block_config.data,
                                       0,
-                                      dst_thread_addr_offset + dst_const_addr_offset,
+                                      dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                                : -1,
                                       false,
                                       false);
 #else
@@ -691,7 +783,8 @@ __device__ void amd_buffer_store<ushort, 2>(const ushort* p_src,
     __llvm_amdgcn_buffer_store_f32(*p_src_tmp,
                                    dst_block_config.data,
                                    0,
-                                   dst_thread_addr_offset + dst_const_addr_offset,
+                                   dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                             : -1,
                                    false,
                                    false);
 #endif
@@ -701,7 +794,8 @@ template <>
 __device__ void amd_buffer_store<ushort, 4>(const ushort* p_src,
                                             ushort* p_dst_block,
                                             index_t dst_thread_data_offset,
-                                            index_t dst_const_data_offset)
+                                            index_t dst_const_data_offset,
+                                            bool dst_valid)
 {
     BufferAddressConfig<ushort> dst_block_config;
@@ -719,7 +813,8 @@ __device__ void amd_buffer_store<ushort, 4>(const ushort* p_src,
     __llvm_amdgcn_buffer_store_bf16x4(*p_src,
                                       dst_block_config.data,
                                       0,
-                                      dst_thread_addr_offset + dst_const_addr_offset,
+                                      dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                                : -1,
                                       false,
                                       false);
 #else
@@ -728,7 +823,8 @@ __device__ void amd_buffer_store<ushort, 4>(const ushort* p_src,
     __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
                                      dst_block_config.data,
                                      0,
-                                     dst_thread_addr_offset + dst_const_addr_offset,
+                                     dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                               : -1,
                                      false,
                                      false);
 #endif
@@ -738,7 +834,8 @@ template <>
 __device__ void amd_buffer_atomic_add<float, 1>(const float* p_src,
                                                 float* p_dst_block,
                                                 index_t dst_thread_data_offset,
-                                                index_t dst_const_data_offset)
+                                                index_t dst_const_data_offset,
+                                                bool dst_valid)
 {
     BufferAddressConfig<float> dst_block_config;
@@ -752,20 +849,41 @@ __device__ void amd_buffer_atomic_add<float, 1>(const float* p_src,
     index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
     index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);

-    __llvm_amdgcn_buffer_atomic_add_f32(
-        *p_src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
+    __llvm_amdgcn_buffer_atomic_add_f32(*p_src,
+                                        dst_block_config.data,
+                                        0,
+                                        dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset)
+                                                  : -1,
+                                        false);
 }

 template <>
 __device__ void amd_buffer_atomic_add<float, 2>(const float* p_src,
                                                 float* p_dst_block,
                                                 index_t dst_thread_data_offset,
-                                                index_t dst_const_data_offset)
+                                                index_t dst_const_data_offset,
+                                                bool dst_valid)
 {
+    BufferAddressConfig<float> dst_block_config;
+
+    // fill in byte 0 - 1
+    dst_block_config.address[0] = p_dst_block;
+    // fill in byte 2
+    dst_block_config.range[2] = -1;
+    // fill in byte 3
+    dst_block_config.range[3] = 0x00027000;
+
+    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
+    index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
+
     for(index_t i = 0; i < 2; ++i)
     {
-        amd_buffer_atomic_add<float, 1>(
-            &p_src[i], p_dst_block, dst_thread_data_offset, dst_const_data_offset + i);
+        __llvm_amdgcn_buffer_atomic_add_f32(
+            p_src[i],
+            dst_block_config.data,
+            0,
+            dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset + i * sizeof(float)) : -1,
+            false);
     }
 }
@@ -773,12 +891,29 @@ template <>
 __device__ void amd_buffer_atomic_add<float, 4>(const float* p_src,
                                                 float* p_dst_block,
                                                 index_t dst_thread_data_offset,
-                                                index_t dst_const_data_offset)
+                                                index_t dst_const_data_offset,
+                                                bool dst_valid)
 {
+    BufferAddressConfig<float> dst_block_config;
+
+    // fill in byte 0 - 1
+    dst_block_config.address[0] = p_dst_block;
+    // fill in byte 2
+    dst_block_config.range[2] = -1;
+    // fill in byte 3
+    dst_block_config.range[3] = 0x00027000;
+
+    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
+    index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
+
     for(index_t i = 0; i < 4; ++i)
     {
-        amd_buffer_atomic_add<float, 1>(
-            &p_src[i], p_dst_block, dst_thread_data_offset, dst_const_data_offset + i);
+        __llvm_amdgcn_buffer_atomic_add_f32(
+            p_src[i],
+            dst_block_config.data,
+            0,
+            dst_valid ? (dst_thread_addr_offset + dst_const_addr_offset + i * sizeof(float)) : -1,
+            false);
     }
 }
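All of the specializations in this header follow the same recipe: the new `src_valid`/`dst_valid` argument never adds a branch around the intrinsic; it only selects between the real byte offset and `-1`, so an invalid access becomes an out-of-range buffer access that the hardware resolves (loads return 0, stores and atomic adds are dropped). A hedged sketch of how a caller would use the new load signature (the wrapper name is invented; `amd_buffer_load` and `index_t` are as declared above):

```cpp
// Sketch: read one float, or get 0.0f back when the coordinate maps into padding.
__device__ float padded_read(const float* p_src_block, index_t data_offset, bool in_bounds)
{
    // The last argument is the new src_valid flag; an invalid read comes back
    // as 0 without a divergent branch.
    return amd_buffer_load<float, 1>(p_src_block, data_offset, 0, in_bounds);
}
```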
@@ -47,10 +47,25 @@ struct SetData
     // This version is only for compatibility, don't use this version if possible
     template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
-    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
+    __device__ void Run(const T* p_src,
+                        index_t src_offset,
+                        bool src_valid,
+                        T* p_dst,
+                        index_t dst_offset,
+                        bool dst_valid) const
     {
-        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
-            *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
+        if(dst_valid)
+        {
+            if(src_valid)
+            {
+                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
+                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
+            }
+            else
+            {
+                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) = 0;
+            }
+        }
     }

 #if CK_USE_AMD_BUFFER_ADDRESSING
@@ -61,11 +76,16 @@ struct SetData
     template <>
     __device__ void Run<AddressSpace::Global, AddressSpace::Vgpr>(const T* p_src,
                                                                   index_t src_offset,
+                                                                  bool src_valid,
                                                                   T* p_dst,
-                                                                  index_t dst_offset) const
+                                                                  index_t dst_offset,
+                                                                  bool dst_valid) const
     {
-        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
-            amd_buffer_load<T, DataPerAccess>(p_src, src_offset, 0);
+        if(dst_valid)
+        {
+            *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
+                amd_buffer_load<T, DataPerAccess>(p_src, src_offset, 0, src_valid);
+        }
     }

     // buffer_store requires:
@@ -75,10 +95,15 @@ struct SetData
     template <>
     __device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
                                                                   index_t src_offset,
+                                                                  bool src_valid,
                                                                   T* p_dst,
-                                                                  index_t dst_offset) const
+                                                                  index_t dst_offset,
+                                                                  bool dst_valid) const
     {
-        amd_buffer_store<T, DataPerAccess>(&(p_src[src_offset]), p_dst, dst_offset, 0);
+        const auto zeros = vector_t(0);
+
+        amd_buffer_store<T, DataPerAccess>(
+            src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, 0, dst_valid);
     }
 #endif
 };
@@ -90,10 +115,18 @@ struct AtomicAddData
     // This version is only for compatibility, don't use this version if possible
     template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
-    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
+    __device__ void Run(const T* p_src,
+                        index_t src_offset,
+                        bool src_valid,
+                        T* p_dst,
+                        index_t dst_offset,
+                        bool dst_valid) const
     {
-        atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
-                        *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
+        if(src_valid && dst_valid)
+        {
+            atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
+                            *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
+        }
     }

 #if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_ADD
@@ -104,10 +137,14 @@ struct AtomicAddData
     template <>
     __device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
                                                                   index_t src_offset,
+                                                                  bool src_valid,
                                                                   T* p_dst,
-                                                                  index_t dst_offset) const
+                                                                  index_t dst_offset,
+                                                                  bool dst_valid) const
     {
-        amd_buffer_atomic_add<T, DataPerAccess>(&(p_src[src_offset]), p_dst, dst_offset, 0);
+        const auto zeros = vector_t(0);
+        amd_buffer_atomic_add<T, DataPerAccess>(
+            src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, 0, dst_valid);
     }
 #endif
 };
@@ -119,7 +156,12 @@ template <typename T,
           InMemoryDataOperation DstInMemOp,
           index_t SrcDataStride = 1,
           index_t DstDataStride = 1>
-__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
+__device__ void transfer_data(const T* p_src,
+                              index_t src_offset,
+                              bool src_valid,
+                              T* p_dst,
+                              index_t dst_offset,
+                              bool dst_valid)
 {
     static_assert(DstInMemOp == InMemoryDataOperation::Set ||
                       DstInMemOp == InMemoryDataOperation::AtomicAdd,
@@ -131,27 +173,37 @@ __device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, inde
         // TODO: use static_if::ElseIf
         static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
             SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
-                p_src, src_offset, p_dst, dst_offset);
+                p_src, src_offset, src_valid, p_dst, dst_offset, dst_valid);
         });

         static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
             AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
-                p_src, src_offset, p_dst, dst_offset);
+                p_src, src_offset, src_valid, p_dst, dst_offset, dst_valid);
         });
     }
     else
     {
-        for(index_t i = 0; i < DataPerAccess; i++)
+        for(index_t i = 0; i < DataPerAccess; ++i)
        {
             // TODO: use static_if::ElseIf
             static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
                 SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
-                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
+                    p_src,
+                    src_offset + i * SrcDataStride,
+                    src_valid,
+                    p_dst,
+                    dst_offset + i * DstDataStride,
+                    dst_valid);
             });

             static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
                 AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
-                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
+                    p_src,
+                    src_offset + i * SrcDataStride,
+                    src_valid,
+                    p_dst,
+                    dst_offset + i * DstDataStride,
+                    dst_valid);
             });
         }
     }
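The non-buffer fallback paths handle the same flags in software: `SetData::Run` zero-fills the destination when the source is invalid and skips the write when the destination is invalid, while `AtomicAddData::Run` only updates when both sides are valid. A small sketch of the resulting semantics (Vgpr-to-Vgpr so the generic `Run` overload is used rather than a buffer specialization; the offsets are arbitrary):

```cpp
// Hedged sketch only; not part of the commit.
__device__ void padding_semantics_sketch()
{
    float src_buf[4] = {1.f, 2.f, 3.f, 4.f};
    float dst_buf[4] = {9.f, 9.f, 9.f, 9.f};

    transfer_data<float,
                  1,
                  AddressSpace::Vgpr,
                  AddressSpace::Vgpr,
                  InMemoryDataOperation::Set>(src_buf,
                                              2,     // src_offset
                                              false, // src_valid: pretend this maps into padding
                                              dst_buf,
                                              0,     // dst_offset
                                              true); // dst_valid

    // dst_buf[0] is now 0: an invalid source zero-fills the destination,
    // while an invalid destination would have skipped the write entirely.
}
```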