Commit b3d4595f authored by Chao Liu

added type conversion in threadwise and blockwise copy

parent 3cb2a7d0
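
For context: the change below threads separate source and destination data types through the blockwise and threadwise copies and performs the per-element conversion with the type_convert functor shown at the end of this diff. A minimal host-side sketch of that pattern (standard C++ only, no LDS and no vectorized access; the buffer and function names here are illustrative, not kernel APIs):

// Host-side sketch only: mirrors the type_convert functor defined in this
// commit and the per-element conversion loop added to Run_generic.
#include <cstdio>

template <typename T>
struct type_convert
{
    template <typename X>
    T operator()(const X& x) const
    {
        return static_cast<T>(x);
    }
};

// Copy N elements from SrcData to DstData, converting each element.
template <typename SrcData, typename DstData, int N>
void copy_with_conversion(const SrcData (&p_src)[N], DstData (&p_dst)[N])
{
    for(int i = 0; i < N; ++i)
    {
        p_dst[i] = type_convert<DstData>{}(p_src[i]);
    }
}

int main()
{
    float p_src[4] = {1.5f, 2.5f, 3.5f, 4.5f};
    int p_dst[4]   = {};
    copy_with_conversion(p_src, p_dst);
    printf("%d %d %d %d\n", p_dst[0], p_dst[1], p_dst[2], p_dst[3]); // prints: 1 2 3 4
    return 0;
}
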
@@ -287,9 +287,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
p_in_global, p_in_block_double);
blockwise_wei_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
p_wei_global, p_wei_block_double);
}
@@ -312,8 +312,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
@@ -321,25 +321,27 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
p_in_global, p_in_register_buffer);
blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
p_wei_global, p_wei_register_buffer);
blockwise_in_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
// even iteration
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
@@ -347,19 +349,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
p_in_global, p_in_register_buffer);
blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
p_wei_global, p_wei_register_buffer);
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
@@ -431,9 +433,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
b_thread_data_on_global,
0})
#if 1
.template Run_generic<Float, Float, address_space_t::generic, address_space_t::global>
.template Run_generic<Float,
Float,
address_space_t::generic,
address_space_t::global>
#elif 1
.template Run_optimized_dst_address_calculation<Float, Float, address_space_t::global>
.template Run_optimized_dst_address_calculation<Float,
Float,
address_space_t::global>
#endif
(p_out_thread, p_out_global);
}
......
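
The kernel above preloads the first tile, then in the main loop overlaps the global-memory load of the next tile (into a per-thread buffer) with the GEMM on the tile already resident in LDS, ping-ponging between the two LDS halves and finishing with an even/odd tail. A simplified host-side sketch of that schedule (the three functions are hypothetical stand-ins that only print the schedule; the real kernel also inserts __syncthreads() between stages):

#include <cstdio>

// Hypothetical stand-ins for global->thread-buffer load, thread-buffer->LDS
// store, and the blockwise GEMM.
void load_next(int tile)    { printf("load tile %d into thread buffer\n", tile); }
void store_next(int tile)   { printf("store tile %d into LDS half %d\n", tile, tile % 2); }
void gemm_current(int tile) { printf("GEMM on LDS half %d (tile %d)\n", tile % 2, tile); }

int main()
{
    const int num_tiles = 4;

    // preload: fill the first LDS half before the main loop
    load_next(0);
    store_next(0);

    // main loop: load tile i while computing on tile i-1
    for(int tile = 1; tile < num_tiles; ++tile)
    {
        load_next(tile);
        gemm_current(tile - 1);
        store_next(tile);
    }

    // tail: compute on the last tile
    gemm_current(num_tiles - 1);
    return 0;
}
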
@@ -678,10 +678,10 @@ struct BlockwiseGenericTensorSliceCopy_v3
};
template <index_t BlockSize,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename SubLengths,
typename BlockSrcDesc,
typename BlockDstDesc,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
@@ -692,24 +692,49 @@ template <index_t BlockSize,
index_t DstDataPerAccess>
struct BlockwiseGenericTensorSliceCopy_v4
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseGenericTensorSliceCopy_v4(const Index& src_block_slice_origin,
const Index& dst_block_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
nDim == SubLengths::Size() && nDim == ThreadClusterLengths::Size() &&
static_assert(nDim == BlockSrcDesc::GetNumOfDimension() &&
nDim == BlockDstDesc::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
#if 1
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
#else
constexpr auto thread_cluster_lengths_in_arrange_order =
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{});
constexpr auto thread_cluster_desc = transform_tensor_descriptor(
make_native_tensor_descriptor_packed(thread_cluster_lengths_in_arrange_order),
make_tuple(Merge<decltype(thread_cluster_lengths_in_arrange_order)>{}),
// assumption: lower/upper dimension ids follow the same merged-descriptor
// pattern as thread_cluster_id below; this branch is disabled by the #if
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
make_tuple(Sequence<0>{}));
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! BlockSize not consistent with ThreadClusterLengths");
constexpr auto thread_cluster_id = transform_tensor_descriptor(
make_native_tensor_descriptor_packed(Sequence<KBlockWork, BBlockWork>{}),
make_tuple(Merge<Sequence<KBlockWork, BBlockWork>>{}),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
const auto block_work_multi_id = block_work_desc.CalculateLowerIndex(get_block_1d_id());
#endif
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! BlockSize not consistent with ThreadClusterLengths");
@@ -720,7 +745,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
const auto data_cluster_id =
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
const auto thread_data_id_begin = data_cluster_id * ThreadSliceLengths{};
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
@@ -729,51 +754,70 @@ struct BlockwiseGenericTensorSliceCopy_v4
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
__device__ static constexpr index_t GetRegisterBufferSize()
__device__ static constexpr index_t GetThreadBufferSize()
{
return RegisterBufferDesc::GetElementSpace();
return ThreadBufferDesc::GetElementSpace();
}
template <typename SrcData, typename BufferData, address_space_t SrcAddressSpace = address_space_t::generic>
__device__ void RunLoadRegisterBuffer(const SrcData* p_src, BufferData* p_buffer) const
template <typename BlockSrcData,
typename ThreadBufferData,
address_space_t BlockSrcAddressSpace = address_space_t::generic,
address_space_t ThreadBufferAddressSpace = address_space_t::generic>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
ThreadBufferData* p_thread_buffer) const
{
#if 1
mThreadwiseLoad.template Run_generic<SrcData, BufferData, SrcAddressSpace, address_space_t::generic>(
p_src, p_buffer);
mThreadwiseLoad.template Run_generic<BlockSrcData,
ThreadBufferData,
BlockSrcAddressSpace,
ThreadBufferAddressSpace>(p_block_src,
p_thread_buffer);
#else
mThreadwiseLoad.template Run_optimized_src_address_calculation<SrcData,
BufferData,
SrcAddressSpace,
address_space_t::generic>(
p_src, p_buffer);
mThreadwiseLoad.template Run_optimized_src_address_calculation<BlockSrcData,
ThreadBufferData,
BlockSrcAddressSpace,
ThreadBufferAddressSpace>(
p_block_src, p_thread_buffer);
#endif
}
template <typename BufferData, typename DstData, address_space_t DstAddressSpace = address_space_t::generic>
__device__ void RunStoreRegisterBuffer(const BufferData* p_buffer, DstData* p_dst) const
template <typename ThreadBufferData,
typename BlockDstData,
address_space_t ThreadBufferAddressSpace = address_space_t::generic,
address_space_t BlockDstAddressSpace = address_space_t::generic>
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
BlockDstData* p_block_dst) const
{
#if 1
mThreadwiseStore.template Run_generic<BufferData, DstData, address_space_t::generic, DstAddressSpace>(
p_buffer, p_dst);
mThreadwiseStore.template Run_generic<ThreadBufferData,
BlockDstData,
ThreadBufferAddressSpace,
BlockDstAddressSpace>(p_thread_buffer, p_block_dst);
#else
mThreadwiseStore.template Run_optimized_dst_address_calculation<BufferData,
DstData,
address_space_t::generic,
DstAddressSpace>(p_buffer,
p_dst);
mThreadwiseStore.template Run_optimized_dst_address_calculation<ThreadBufferData,
BlockDstData,
ThreadBufferAddressSpace,
BlockDstAddressSpace>(
p_thread_buffer, p_block_dst);
#endif
}
template <typename SrcData,
typename DstData,
address_space_t SrcAddressSpace = address_space_t::generic,
address_space_t DstAddressSpace = address_space_t::generic>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
template <typename BlockSrcData,
typename BlockDstData,
address_space_t BlockSrcAddressSpace = address_space_t::generic,
address_space_t BlockDstAddressSpace = address_space_t::generic>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{
SrcData p_src_buffer[GetRegisterBufferSize()];
RunLoadRegisterBuffer<SrcData, SrcData, SrcAddressSpace>(p_src, p_buffer);
RunStoreRegisterBuffer<SrcData, DstData, DstAddressSpace>(p_buffer, p_dst);
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
RunLoadThreadBuffer<BlockSrcData,
BlockSrcData,
BlockSrcAddressSpace,
address_space_t::generic>(p_block_src, p_thread_buffer);
RunStoreThreadBuffer<BlockSrcData,
BlockDstData,
address_space_t::generic,
BlockDstAddressSpace>(p_thread_buffer, p_block_dst);
}
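
Callers now spell out both the source and destination data types; the gridwise kernel above invokes the blockwise copy as, for example:

// same type on both sides here, but the two type parameters let a caller load
// one precision and store another when a conversion is wanted
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
    p_in_global, p_in_block_double);
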
template <typename T, bool PositiveDirection>
@@ -793,19 +837,19 @@ struct BlockwiseGenericTensorSliceCopy_v4
}
private:
using RegisterBufferDesc = decltype(make_native_tensor_descriptor_packed(SubLengths{}));
using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<SrcDesc,
RegisterBufferDesc,
SubLengths,
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<BlockSrcDesc,
ThreadBufferDesc,
ThreadSliceLengths,
SrcDimAccessOrder,
SrcVectorAccessDim,
SrcDataPerAccess,
1>;
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<RegisterBufferDesc,
DstDesc,
SubLengths,
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
BlockDstDesc,
ThreadSliceLengths,
DstDimAccessOrder,
DstVectorAccessDim,
1,
......
@@ -1180,7 +1180,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area.
template <typename SrcData,
typename DstData,
typename DstData,
address_space_t SrcAddressSpace = address_space_t::generic,
address_space_t DstAddressSpace = address_space_t::generic>
__device__ void Run_generic(const SrcData* p_src, DstData* p_dst) const
@@ -1233,7 +1233,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
__buffer_load<SrcData, SrcDataPerAccess>(p_src, src_coord.GetOffset(), 0);
__buffer_load<SrcData, SrcDataPerAccess>(
p_src, src_coord.GetOffset(), 0);
#else
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
@@ -1246,12 +1247,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
}
}
// SrcData to DstData conversion
// SrcData to DstData conversion
DstData p_dst_long_vector[long_vector_size];
for(index_t i = 0; i < long_vector_size; ++i)
for(index_t i = 0; i < long_vector_size; ++i)
{
p_dst_long_vector[i] = type_convert<DstData>(p_src_long_vector[i]);
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
}
// store data from the long-vector buffer to dst
......
@@ -38,11 +38,11 @@ typedef float float4_t __attribute__((ext_vector_type(4)));
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
// data type conversion
template <class T>
template <typename T>
struct type_convert
{
template <class X>
__device__ T operator()(X x) const
template <typename X>
__device__ T operator()(const X& x) const
{
return static_cast<T>(x);
}
......
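
The braces construct a temporary type_convert<T> and immediately invoke it, which is the shape the Run_generic conversion loop above now uses; ignoring the __device__ qualifier, the behavior is plain static_cast. For illustration:

// example values only; ordinary static_cast semantics
float f = type_convert<float>{}(3);   // int -> float, f == 3.0f
int   n = type_convert<int>{}(2.7f);  // float -> int, truncates to 2
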