"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "0b7ade0d4cf1ecc475b9bc94b4a2a96ce093504b"
Commit cf218184 authored by Chao Liu

enable type conversion in blockwise copy v2 and threadwise copy v2r1

parent 012d3a07
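
What the change does, in short: the blockwise copy (v2) and threadwise copy (v2r1) helpers used to take a single element type TData; they now take separate SrcData and DstData parameters, and each element is converted from the source type to the destination type as part of the copy. The sketch below only illustrates the call-shape change at the gridwise call sites; in this commit all call sites still pass Float for both parameters, so no conversion actually happens yet.

    // before: one element type shared by source and destination
    blockwise_in_copy.template Run<Float, address_space_t::global>(p_in_global, p_in_block_double);

    // after: separate source/destination element types (a no-op conversion when they match)
    blockwise_in_copy.template Run<Float, Float, address_space_t::global>(p_in_global, p_in_block_double);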
@@ -265,10 +265,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer

         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.template Run<Float, address_space_t::global>(p_in_global,
-                                                                           p_in_block_double);
-            blockwise_wei_copy.template Run<Float, address_space_t::global>(p_wei_global,
-                                                                            p_wei_block_double);
+            blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
+                p_in_global, p_in_block_double);
+            blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
+                p_wei_global, p_wei_block_double);
         }

         // LDS double buffer: main body
@@ -299,10 +299,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
-                blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
-                    p_in_global, p_in_thread_buffer);
-                blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
-                    p_wei_global, p_wei_thread_buffer);
+                blockwise_in_copy
+                    .template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                        p_in_global, p_in_thread_buffer);
+                blockwise_wei_copy
+                    .template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                        p_wei_global, p_wei_thread_buffer);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -325,9 +327,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
-                p_in_global, p_in_thread_buffer);
-            blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
-                p_wei_global, p_wei_thread_buffer);
+            blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                p_wei_global, p_wei_thread_buffer);

             // LDS double buffer: GEMM on current data
@@ -396,7 +398,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                                    0,
                                    b_thread_data_on_global,
                                    0})
-                .template Run<Float, address_space_t::generic, address_space_t::global>(
+                .template Run<Float, Float, address_space_t::generic, address_space_t::global>(
                     p_out_thread, p_out_global);
         }
     }
...
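Note on the hunks above: every call site in this gridwise kernel passes Float for both element types, so the kernel's numerics are unchanged; the extra parameter only matters once the two types differ. A hypothetical mixed-precision call (illustration only; the 16-bit type name and the fp16 LDS buffer are assumptions, not something this commit introduces) would look like:

    // load fp32 from global memory and store fp16 into LDS;
    // the per-element conversion happens inside the copy
    blockwise_in_copy.template Run<float, _Float16, address_space_t::global>(
        p_in_global, p_in_block_double_fp16);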
@@ -120,6 +120,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
                              BlockSrcData,
                              BlockSrcAddressSpace,
                              address_space_t::generic>(p_block_src, p_thread_buffer);
+
+        // if there is type conversion, it's done during store
         RunStoreThreadBuffer<BlockSrcData,
                              BlockDstData,
                              address_space_t::generic,
...
@@ -478,35 +478,42 @@ struct BlockwiseGenericTensorSliceCopy_v2
         return ThreadBufferDesc::GetElementSpace();
     }

-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t BlockSrcAddressSpace     = address_space_t::generic,
               address_space_t ThreadBufferAddressSpace = address_space_t::generic>
-    __device__ void RunLoadThreadBuffer(const TData* p_block_src, TData* p_thread_buffer) const
+    __device__ void RunLoadThreadBuffer(const SrcData* p_block_src, DstData* p_thread_buffer) const
     {
-        mThreadwiseLoad.template Run<TData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(
-            p_block_src, p_thread_buffer);
+        mThreadwiseLoad
+            .template Run<SrcData, DstData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(
+                p_block_src, p_thread_buffer);
     }

-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t ThreadBufferAddressSpace = address_space_t::generic,
               address_space_t BlockDstAddressSpace     = address_space_t::generic>
-    __device__ void RunStoreThreadBuffer(const TData* p_thread_buffer, TData* p_block_dst) const
+    __device__ void RunStoreThreadBuffer(const SrcData* p_thread_buffer, DstData* p_block_dst) const
     {
-        mThreadwiseStore.template Run<TData, ThreadBufferAddressSpace, BlockDstAddressSpace>(
-            p_thread_buffer, p_block_dst);
+        mThreadwiseStore
+            .template Run<SrcData, DstData, ThreadBufferAddressSpace, BlockDstAddressSpace>(
+                p_thread_buffer, p_block_dst);
     }

-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t BlockSrcAddressSpace = address_space_t::generic,
               address_space_t BlockDstAddressSpace = address_space_t::generic>
-    __device__ void Run(const TData* p_block_src, TData* p_block_dst) const
+    __device__ void Run(const SrcData* p_block_src, DstData* p_block_dst) const
     {
-        TData p_thread_buffer[GetThreadBufferSize()];
+        SrcData p_thread_buffer[GetThreadBufferSize()];

-        RunLoadThreadBuffer<TData, BlockSrcAddressSpace, address_space_t::generic>(p_block_src,
-                                                                                   p_thread_buffer);
-        RunStoreThreadBuffer<TData, address_space_t::generic, BlockDstAddressSpace>(p_thread_buffer,
-                                                                                    p_block_dst);
+        RunLoadThreadBuffer<SrcData, SrcData, BlockSrcAddressSpace, address_space_t::generic>(
+            p_block_src, p_thread_buffer);
+
+        // if there is type conversion, it's done during store
+        RunStoreThreadBuffer<SrcData, DstData, address_space_t::generic, BlockDstAddressSpace>(
+            p_thread_buffer, p_block_dst);
     }

     template <typename T, bool PositiveDirection>
...
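For readers not familiar with these copy helpers, the flow that BlockwiseGenericTensorSliceCopy_v2::Run now implements can be modelled by the standalone sketch below (plain host-side C++, illustration only; the real code is per-thread GPU code with descriptor-driven, vectorized addressing). The design choice is visible in the diff above: the staging buffer stays in the source type, the load does no conversion, and the conversion is paid once during the store.

    #include <cstddef>

    // Simplified model of Run(): the load keeps SrcData, conversion happens on the store.
    template <typename SrcData, typename DstData, std::size_t N>
    void run_copy_with_conversion(const SrcData* p_block_src, DstData* p_block_dst)
    {
        SrcData p_thread_buffer[N]; // staging buffer stays in the source element type

        // "RunLoadThreadBuffer<SrcData, SrcData, ...>": plain copy, no conversion
        for(std::size_t i = 0; i < N; ++i)
            p_thread_buffer[i] = p_block_src[i];

        // "RunStoreThreadBuffer<SrcData, DstData, ...>": convert element-wise on the way out
        for(std::size_t i = 0; i < N; ++i)
            p_block_dst[i] = static_cast<DstData>(p_thread_buffer[i]);
    }

Keeping the buffer in the source type leaves the (possibly vectorized) load path untouched; only the final store pays for the conversion.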
@@ -537,19 +537,20 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
         }
     };

-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t SrcAddressSpace = address_space_t::generic,
               address_space_t DstAddressSpace = address_space_t::generic>
-    __device__ void Run(const TData* p_src, TData* p_dst) const
+    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
     {
         constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});

-        TData p_buffer_[buffer_desc.GetElementSpace()];
-        TData* p_buffer = p_buffer_;
+        SrcData p_src_buffer_[buffer_desc.GetElementSpace()];
+        SrcData* p_src_buffer = p_src_buffer_;

         // copy data from src into buffer
         {
-            using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
+            using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;

             constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
             constexpr auto src_data_per_access   = Number<SrcDataPerAccess>{};
@@ -573,77 +574,88 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
             constexpr auto src_normal_dim_access_lengths =
                 src_access_lengths + Number<1>{} - src_merged_dim_access_lengths;

-            ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}([&](
-                auto src_merged_dim_access_id) {
+            ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}(
+                [&](auto src_merged_dim_access_id) {

                     auto src_merged_dim_data_id = src_merged_dim_access_id;
                     src_merged_dim_data_id(src_vector_access_dim) =
                         src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access;

                     // offset w.r.t. merged dimension need be computed at run-time,
                     const index_t src_merged_offset =
                         (mSrcSliceOrigin + src_merged_dim_data_id).GetOffset();

                     ford<decltype(src_normal_dim_access_lengths), SrcDimAccessOrder>{}([&](
                         auto src_normal_dim_access_id) {

                         auto src_normal_dim_data_id = src_normal_dim_access_id;
                         src_normal_dim_data_id(src_vector_access_dim) =
                             src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access;

                         // offset w.r.t. normal dimension is known at compile-time
                         const index_t src_normal_offset =
                             SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id);

                         src_vector_t vector_data;

                         // Read vector from src.
                         // 1. Source code version can take src of all kinds of memory-space
                         // 2. Intrinsic version using buffer_load can only take
                         //    src from global-memory
                         //
                         // Commemt for loading from global-memory:
                         // When:
                         //   1) using source code, in order for compiler to emit optimal
                         //      load instruction, or
                         //   2) using buffer_load intrinsic, in order for ISA to be valid,
                         // following assumptions need to be satisfied:
                         //   1. p_src need to be block-invariant (assumption)
                         //   2. src_normal_offset must be calculatd at compile time (guaranteed by
                         //      algorithm)
                         //   3. src_merged_offset can be runtime value (no assumption imposed)
                         static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
-                        vector_data = __buffer_load<TData, SrcDataPerAccess>(
-                            p_src, src_merged_offset, src_normal_offset);
+                            vector_data = __buffer_load<SrcData, SrcDataPerAccess>(
+                                p_src, src_merged_offset, src_normal_offset);
#else
                             vector_data = *reinterpret_cast<const src_vector_t*>(
                                 &p_src[src_normal_offset + src_merged_offset]);
#endif
                         }).Else([&](auto) {
                             // src can be all kinds of memory-space.
                             vector_data = *reinterpret_cast<const src_vector_t*>(
                                 &p_src[src_normal_offset + src_merged_offset]);
                         });

                         // unpack vector into buffer
                         for(index_t i = 0; i < SrcDataPerAccess; ++i)
                         {
                             auto scalar_id                   = make_zero_array<index_t, nDim>();
                             scalar_id(src_vector_access_dim) = i;

                             const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
                                 src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);

-                        p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
+                            p_src_buffer[buffer_offset] =
+                                reinterpret_cast<const SrcData*>(&vector_data)[i];
                         }
                     });
                 });
         }

+        // type conversion
+        // TODO: would compiler do a good job reusing register for buffer?
+        DstData p_dst_buffer_[buffer_desc.GetElementSpace()];
+        DstData* p_dst_buffer = p_dst_buffer_;
+
+        ford<SliceLengths>{}([&](auto idx) {
+            p_dst_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)] =
+                type_convert<DstData>{}(p_src_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)]);
+        });
+
         // copy data from buffer into dst
         {
-            using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
+            using dst_vector_t = typename vector_type<SrcData, DstDataPerAccess>::MemoryType;

             constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
             constexpr auto dst_data_per_access   = Number<DstDataPerAccess>{};
@@ -659,72 +671,72 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
             constexpr auto dst_normal_dim_access_lengths =
                 dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths;

-            ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}(
-                [&](auto dst_merged_dim_access_id) {
+            ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}([&](
+                auto dst_merged_dim_access_id) {

                 auto dst_merged_dim_data_id = dst_merged_dim_access_id;
                 dst_merged_dim_data_id(dst_vector_access_dim) =
                     dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access;

                 // offset w.r.t. merged dimension need be computed at run-time,
                 const index_t dst_merged_offset =
                     (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();

                 ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}([&](
                     auto dst_normal_dim_access_id) {

                     auto dst_normal_dim_data_id = dst_normal_dim_access_id;
                     dst_normal_dim_data_id(dst_vector_access_dim) =
                         dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access;

                     dst_vector_t vector_data;

                     // pack vector from buffer
                     for(index_t i = 0; i < DstDataPerAccess; ++i)
                     {
                         auto scalar_id                   = make_zero_array<index_t, nDim>();
                         scalar_id(dst_vector_access_dim) = i;

                         const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
                             dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);

-                        reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
+                        reinterpret_cast<SrcData*>(&vector_data)[i] = p_dst_buffer[buffer_offset];
                     }

                     // offset w.r.t. normal dimension is known at compile-time
                     const index_t dst_normal_offset =
                         DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);

                     // Write vector into dst.
                     // 1. Source code version can take dst of all kinds of memory-space
                     // 2. Intrinsic version using buffer_store can only take
                     //    dst from global-memory
                     //
                     // Commemt for storing into global-memory:
                     // When:
                     //   1) using source code, in order for compiler to emit optimal
                     //      store instruction, or
                     //   2) using buffer_store, intrinsic in order ISA to be valid
                     // following assumptions need to be satisfied:
                     //   1. p_dst need to be block-invariant (assumption)
                     //   2. dst_normal_offset must be calculatd at compile time (guaranteed by
                     //      algorithm)
                     //   3. dst_merged_offset can be runtime value (no assumption imposed)
                     static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
-                        __buffer_store<TData, DstDataPerAccess>(
-                            vector_data, p_dst, dst_merged_offset, dst_normal_offset);
+                        __buffer_store<SrcData, DstDataPerAccess>(
+                            vector_data, p_dst, dst_merged_offset, dst_normal_offset);
#else
                         *reinterpret_cast<dst_vector_t*>(
                             &p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
#endif
                     }).Else([&](auto) {
                         // dst can be all kinds of memory-space
                         *reinterpret_cast<dst_vector_t*>(
                             &p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
                     });
                 });
             });
         }
     }
...
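The conversion step added to ThreadwiseGenericTensorSliceCopy_v2r1::Run goes through a type_convert<DstData>{} functor whose definition is not part of this diff. A plausible minimal form, shown only so the inserted loop reads clearly (an assumption about the helper, not the repository's actual definition), is:

    // assumed minimal shape of the conversion functor used by the new ford<SliceLengths> loop
    template <typename DstData>
    struct type_convert
    {
        template <typename SrcData>
        __device__ DstData operator()(SrcData src) const
        {
            return static_cast<DstData>(src);
        }
    };

With something of that shape, the added loop is an element-wise cast from the SrcData staging buffer into a second DstData buffer, after which the existing vectorized store path writes the converted data out (the TODO in the diff records the open question of whether the compiler will reuse registers across the two buffers).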
@@ -295,7 +295,7 @@ int main(int argc, char* argv[])

     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
     constexpr index_t N = 128;
@@ -341,7 +341,7 @@ int main(int argc, char* argv[])

     using LeftPads  = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
-#elif 1
+#elif 0
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -438,7 +438,7 @@ int main(int argc, char* argv[])
 #elif 0
     device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
+#elif 1
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
...