Commit cf218184 authored by Chao Liu

enable type conversion in blockwise copy v2 and threadwise copy v2r1

parent 012d3a07
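The core of the change is that the copy operators now take separate source and destination data types, so a conversion can happen inside the copy itself. A minimal stand-alone sketch of that idea, assuming type_convert is a plain static_cast-based functor (the diff only shows the functor being called, so its definition below is an assumption):

#include <cstddef>

// Hypothetical sketch only: type_convert is assumed to be a static_cast
// wrapper; the real functor in the repo may do more (e.g. rounding for
// float-to-half conversion).
template <typename DstData>
struct type_convert
{
    template <typename SrcData>
    DstData operator()(SrcData v) const
    {
        return static_cast<DstData>(v);
    }
};

// Copy with independent SrcData/DstData template parameters, converting on
// store; this is the signature shape the copy operators adopt in this commit.
template <typename SrcData, typename DstData>
void copy_convert(const SrcData* p_src, DstData* p_dst, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
        p_dst[i] = type_convert<DstData>{}(p_src[i]);
}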
......@@ -265,10 +265,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, address_space_t::global>(p_in_global,
p_in_block_double);
blockwise_wei_copy.template Run<Float, address_space_t::global>(p_wei_global,
p_wei_block_double);
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
p_in_global, p_in_block_double);
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
......@@ -299,9 +299,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
blockwise_in_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
blockwise_wei_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
......@@ -325,9 +327,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
......@@ -396,7 +398,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
0,
b_thread_data_on_global,
0})
.template Run<Float, address_space_t::generic, address_space_t::global>(
.template Run<Float, Float, address_space_t::generic, address_space_t::global>(
p_out_thread, p_out_global);
}
}
......
......@@ -120,6 +120,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
BlockSrcData,
BlockSrcAddressSpace,
address_space_t::generic>(p_block_src, p_thread_buffer);
// if there is type conversion, it's done during store
RunStoreThreadBuffer<BlockSrcData,
BlockDstData,
address_space_t::generic,
......
......@@ -478,35 +478,42 @@ struct BlockwiseGenericTensorSliceCopy_v2
return ThreadBufferDesc::GetElementSpace();
}
template <typename TData,
template <typename SrcData,
typename DstData,
address_space_t BlockSrcAddressSpace = address_space_t::generic,
address_space_t ThreadBufferAddressSpace = address_space_t::generic>
__device__ void RunLoadThreadBuffer(const TData* p_block_src, TData* p_thread_buffer) const
__device__ void RunLoadThreadBuffer(const SrcData* p_block_src, DstData* p_thread_buffer) const
{
mThreadwiseLoad.template Run<TData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(
mThreadwiseLoad
.template Run<SrcData, DstData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(
p_block_src, p_thread_buffer);
}
template <typename TData,
template <typename SrcData,
typename DstData,
address_space_t ThreadBufferAddressSpace = address_space_t::generic,
address_space_t BlockDstAddressSpace = address_space_t::generic>
__device__ void RunStoreThreadBuffer(const TData* p_thread_buffer, TData* p_block_dst) const
__device__ void RunStoreThreadBuffer(const SrcData* p_thread_buffer, DstData* p_block_dst) const
{
mThreadwiseStore.template Run<TData, ThreadBufferAddressSpace, BlockDstAddressSpace>(
mThreadwiseStore
.template Run<SrcData, DstData, ThreadBufferAddressSpace, BlockDstAddressSpace>(
p_thread_buffer, p_block_dst);
}
template <typename TData,
template <typename SrcData,
typename DstData,
address_space_t BlockSrcAddressSpace = address_space_t::generic,
address_space_t BlockDstAddressSpace = address_space_t::generic>
__device__ void Run(const TData* p_block_src, TData* p_block_dst) const
__device__ void Run(const SrcData* p_block_src, DstData* p_block_dst) const
{
TData p_thread_buffer[GetThreadBufferSize()];
SrcData p_thread_buffer[GetThreadBufferSize()];
RunLoadThreadBuffer<TData, BlockSrcAddressSpace, address_space_t::generic>(p_block_src,
p_thread_buffer);
RunStoreThreadBuffer<TData, address_space_t::generic, BlockDstAddressSpace>(p_thread_buffer,
p_block_dst);
RunLoadThreadBuffer<SrcData, SrcData, BlockSrcAddressSpace, address_space_t::generic>(
p_block_src, p_thread_buffer);
// if there is type conversion, it's done during store
RunStoreThreadBuffer<SrcData, DstData, address_space_t::generic, BlockDstAddressSpace>(
p_thread_buffer, p_block_dst);
}
template <typename T, bool PositiveDirection>
......
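In BlockwiseGenericTensorSliceCopy_v2 above, Run keeps the thread buffer in SrcData and defers the conversion to the store pass, matching the "if there is type conversion, it's done during store" comment. A rough host-side sketch of that split, with hypothetical names standing in for RunLoadThreadBuffer and RunStoreThreadBuffer:

#include <cstddef>
#include <vector>

// Hypothetical illustration of the blockwise split: the load pass copies
// without conversion, the store pass converts, and the staging buffer stays
// in the source type.
template <typename SrcData>
void load_thread_buffer(const SrcData* p_block_src, SrcData* p_thread_buffer, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
        p_thread_buffer[i] = p_block_src[i]; // no conversion on load
}

template <typename SrcData, typename DstData>
void store_thread_buffer(const SrcData* p_thread_buffer, DstData* p_block_dst, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
        p_block_dst[i] = static_cast<DstData>(p_thread_buffer[i]); // conversion on store
}

template <typename SrcData, typename DstData>
void run(const SrcData* p_block_src, DstData* p_block_dst, std::size_t n)
{
    std::vector<SrcData> thread_buffer(n); // buffer kept in SrcData, as in Run above
    load_thread_buffer(p_block_src, thread_buffer.data(), n);
    store_thread_buffer(thread_buffer.data(), p_block_dst, n);
}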
......@@ -537,19 +537,20 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
}
};
template <typename TData,
template <typename SrcData,
typename DstData,
address_space_t SrcAddressSpace = address_space_t::generic,
address_space_t DstAddressSpace = address_space_t::generic>
__device__ void Run(const TData* p_src, TData* p_dst) const
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
TData p_buffer_[buffer_desc.GetElementSpace()];
TData* p_buffer = p_buffer_;
SrcData p_src_buffer_[buffer_desc.GetElementSpace()];
SrcData* p_src_buffer = p_src_buffer_;
// copy data from src into buffer
{
using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
......@@ -573,8 +574,8 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
constexpr auto src_normal_dim_access_lengths =
src_access_lengths + Number<1>{} - src_merged_dim_access_lengths;
ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}([&](
auto src_merged_dim_access_id) {
ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}(
[&](auto src_merged_dim_access_id) {
auto src_merged_dim_data_id = src_merged_dim_access_id;
src_merged_dim_data_id(src_vector_access_dim) =
......@@ -614,7 +615,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
// 3. src_merged_offset can be runtime value (no assumption imposed)
static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
vector_data = __buffer_load<TData, SrcDataPerAccess>(
vector_data = __buffer_load<SrcData, SrcDataPerAccess>(
p_src, src_merged_offset, src_normal_offset);
#else
vector_data = *reinterpret_cast<const src_vector_t*>(
......@@ -635,15 +636,26 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
p_src_buffer[buffer_offset] =
reinterpret_cast<const SrcData*>(&vector_data)[i];
}
});
});
}
// type conversion
// TODO: would compiler do a good job reusing register for buffer?
DstData p_dst_buffer_[buffer_desc.GetElementSpace()];
DstData* p_dst_buffer = p_dst_buffer_;
ford<SliceLengths>{}([&](auto idx) {
p_dst_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)] =
type_convert<DstData>{}(p_src_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)]);
});
// copy data from buffer into dst
{
using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
......@@ -659,8 +671,8 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
constexpr auto dst_normal_dim_access_lengths =
dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths;
ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}(
[&](auto dst_merged_dim_access_id) {
ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}([&](
auto dst_merged_dim_access_id) {
auto dst_merged_dim_data_id = dst_merged_dim_access_id;
dst_merged_dim_data_id(dst_vector_access_dim) =
......@@ -688,7 +700,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);
reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
reinterpret_cast<DstData*>(&vector_data)[i] = p_dst_buffer[buffer_offset];
}
// offset w.r.t. normal dimension is known at compile-time
......@@ -712,7 +724,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
// 3. dst_merged_offset can be runtime value (no assumption imposed)
static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
__buffer_store<TData, DstDataPerAccess>(
__buffer_store<DstData, DstDataPerAccess>(
vector_data, p_dst, dst_merged_offset, dst_normal_offset);
#else
*reinterpret_cast<dst_vector_t*>(
......
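ThreadwiseGenericTensorSliceCopy_v2r1::Run now stages through two register buffers: vectorized loads fill a SrcData buffer, an element-wise pass converts into a DstData buffer, and vectorized stores read from that. A simplified scalar sketch of the three stages (no merged-dimension handling, vectorization, or buffer intrinsics):

#include <cstddef>

// Simplified three-stage flow mirroring Run above: load into a SrcData
// staging buffer, convert element-wise, then write the DstData buffer out.
template <typename SrcData, typename DstData, std::size_t N>
void threadwise_copy_convert(const SrcData* p_src, DstData* p_dst)
{
    SrcData p_src_buffer[N];
    DstData p_dst_buffer[N];

    // 1. copy data from src into the SrcData buffer
    for(std::size_t i = 0; i < N; ++i)
        p_src_buffer[i] = p_src[i];

    // 2. type conversion between the two staging buffers
    for(std::size_t i = 0; i < N; ++i)
        p_dst_buffer[i] = static_cast<DstData>(p_src_buffer[i]);

    // 3. copy data from the DstData buffer into dst
    for(std::size_t i = 0; i < N; ++i)
        p_dst[i] = p_dst_buffer[i];
}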
......@@ -295,7 +295,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
......@@ -341,7 +341,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
......@@ -438,7 +438,7 @@ int main(int argc, char* argv[])
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 0
#elif 1
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......