Commit 8bdaba51 authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent fab2f10a
...@@ -241,16 +241,15 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn ...@@ -241,16 +241,15 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#else #else
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
p_in_register_clipboard); blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_buffer);
p_wei_register_clipboard);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block); blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block); blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block);
#endif #endif
__syncthreads(); __syncthreads();
......
...@@ -4,11 +4,8 @@ ...@@ -4,11 +4,8 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp" #include "blockwise_batched_gemm.hpp"
namespace ck { namespace ck {
...@@ -37,10 +34,13 @@ template <index_t GridSize, ...@@ -37,10 +34,13 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA, index_t GemmDataPerReadA,
index_t GemmDataPerReadB, index_t GemmDataPerReadB,
class InBlockCopySubLengths_CHWN,
class InBlockCopyClusterLengths_CHWN, class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N, index_t InBlockCopyDataPerAccess_N,
index_t WeiBlockCopyDataPerRead_K, class WeiBlockCopySubLengths_CK,
index_t OutThreadCopyDataPerWrite_N> class WeiBlockCopyClusterLengths_CK,
index_t WeiBlockCopyDataPerAccess_K,
index_t OutThreadCopyDataPerAccess_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
...@@ -103,8 +103,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -103,8 +103,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
// LDS tensor view // LDS tensor view
// be careful of alignment // be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N, constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
WeiBlockCopyDataPerRead_K, WeiBlockCopyDataPerAccess_K,
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB); GemmDataPerReadB);
...@@ -123,24 +123,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -123,24 +123,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed( constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{}); Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy // blockwise copy
// input: format is [C, Hi, Wi, N] // input: format is [C, Hi, Wi, N]
#if 0
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
#elif 0
using InBlockCopySubLengths_CHWN =
decltype(in_c_h_w_n_block_desc.GetLengths() / InBlockCopyClusterLengths_CHWN{});
auto blockwise_in_copy = auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockwiseGenericTensorSliceCopy_v2<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc), decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc), decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()), decltype(in_c_h_w_n_block_desc.GetLengths()),
...@@ -149,33 +135,28 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -149,33 +135,28 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
1, 3,
1>({0, 0, 0, 0}, {0, 0, 0, 0}); 3,
#elif 1 InBlockCopyDataPerAccess_N,
using InBlockCopySubLengths_CHWN = InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
decltype(in_c_h_w_n_block_desc.GetLengths() / InBlockCopyClusterLengths_CHWN{}); {0, 0, 0, 0});
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
NormalTensorCoordinate<decltype(in_c_h_w_n_global_desc)>,
NormalTensorCoordinate<decltype(in_c_h_w_n_block_desc)>,
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
Sequence<0, 1, 2, 3>>({0, 0, 0, 0}, {0, 0, 0, 0});
#endif
// blockwise wei copy // blockwise wei copy
// format is [CPerBlock, X * KPerBlock] // format is [CPerBlock, KPerBlock]
const auto blockwise_wei_copy = const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize, BlockwiseGenericTensorSliceCopy_v2<BlockSize,
Float, decltype(wei_c_k_global_desc),
decltype(wei_c_k_global_desc), decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc), decltype(wei_c_k_block_desc.GetLengths()),
decltype(wei_c_k_block_desc.GetLengths()), WeiBlockCopySubLengths_CK,
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0}); WeiBlockCopyClusterLengths_CK,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
...@@ -278,7 +259,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -278,7 +259,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
} }
} }
// output: register to global mem, // output: register to global mem
const auto c_thread_mtx_begin = const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
...@@ -329,17 +310,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -329,17 +310,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy(out_10d_thread_desc, Float* p_out_thread_on_global = p_out_global +
p_out_thread, out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_10d_global_desc, k_block_data_begin + k_thread_data_begin,
p_out_global + ho_block_data_begin + ho_thread_data_begin,
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( wo_block_data_begin + wo_thread_data_begin,
k_block_data_begin + k_thread_data_begin, n_block_data_begin + n_thread_data_begin);
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
n_block_data_begin + n_thread_data_begin), decltype(out_10d_global_desc),
out_10d_thread_desc.GetLengths(), decltype(out_10d_thread_desc.GetLengths()),
Number<OutThreadCopyDataPerWrite_N>{}); arithmetic_sequence_gen<0, 10, 1>::type,
arithmetic_sequence_gen<0, 10, 1>::type,
9,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
}).Else([&](auto fwd) { }).Else([&](auto fwd) {
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0, GemmNPerThreadSubC % NPerThread == 0,
...@@ -380,17 +368,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -380,17 +368,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy(out_10d_thread_desc, Float* p_out_thread_on_global = p_out_global +
p_out_thread, out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_10d_global_desc, k_block_data_begin + k_thread_data_begin,
p_out_global + ho_block_data_begin + ho_thread_data_begin,
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( wo_block_data_begin + wo_thread_data_begin,
k_block_data_begin + k_thread_data_begin, n_block_data_begin + n_thread_data_begin);
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
n_block_data_begin + n_thread_data_begin), decltype(out_10d_global_desc),
out_10d_thread_desc.GetLengths(), decltype(out_10d_thread_desc.GetLengths()),
Number<OutThreadCopyDataPerWrite_N>{}); arithmetic_sequence_gen<0, 10, 1>::type,
arithmetic_sequence_gen<0, 10, 1>::type,
9,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
}); });
} }
}; };
......
...@@ -4,10 +4,8 @@ ...@@ -4,10 +4,8 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp" #include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_4d_tensor_op.hpp" #include "threadwise_generic_tensor_slice_copy.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp" #include "blockwise_batched_gemm.hpp"
namespace ck { namespace ck {
...@@ -36,10 +34,13 @@ template <index_t GridSize, ...@@ -36,10 +34,13 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA, index_t GemmDataPerReadA,
index_t GemmDataPerReadB, index_t GemmDataPerReadB,
class InBlockCopySubLengths_CHWN,
class InBlockCopyClusterLengths_CHWN, class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N, index_t InBlockCopyDataPerAccess_N,
index_t WeiBlockCopyDataPerRead_K, class WeiBlockCopySubLengths_CK,
index_t OutThreadCopyDataPerWrite_N> class WeiBlockCopyClusterLengths_CK,
index_t WeiBlockCopyDataPerAccess_K,
index_t OutThreadCopyDataPerAccess_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
...@@ -108,8 +109,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -108,8 +109,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// LDS tensor view // LDS tensor view
// be careful of alignment // be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N, constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
WeiBlockCopyDataPerRead_K, WeiBlockCopyDataPerAccess_K,
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB); GemmDataPerReadB);
...@@ -130,24 +131,38 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -130,24 +131,38 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// blockwise copy // blockwise copy
// input: format is [C, Hi, Wi, N] // input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy = auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize, BlockwiseGenericTensorSliceCopy_v2<BlockSize,
Float, decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_global_desc), decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc), decltype(in_c_h_w_n_block_desc.GetLengths()),
decltype(in_c_h_w_n_block_desc.GetLengths()), InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN, InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{}; Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
3,
3,
InBlockCopyDataPerAccess_N,
InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
{0, 0, 0, 0});
// blockwise wei copy // blockwise wei copy
// format is [CPerBlock, X * KPerBlock] // format is [CPerBlock, X * KPerBlock]
const auto blockwise_wei_copy = const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize, BlockwiseGenericTensorSliceCopy_v2<BlockSize,
Float, decltype(wei_c_k_global_desc),
decltype(wei_c_k_global_desc), decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc), decltype(wei_c_k_block_desc.GetLengths()),
decltype(wei_c_k_block_desc.GetLengths()), WeiBlockCopySubLengths_CK,
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0}); WeiBlockCopyClusterLengths_CK,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
...@@ -233,18 +248,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -233,18 +248,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_clipboard); p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_buffer);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double); p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double); p_wei_block_double);
} }
// LDS double buffer: main body // LDS double buffer: main body
...@@ -266,9 +281,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -266,9 +281,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
p_in_global_block_offset += p_in_global_block_offset +=
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0); CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
...@@ -278,25 +292,25 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -278,25 +292,25 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_clipboard); p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_buffer);
blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_next); p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_next); p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration // even iteration
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0); p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
...@@ -305,19 +319,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -305,19 +319,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_clipboard); p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard( blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_register_clipboard, p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -330,7 +344,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -330,7 +344,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
} }
} }
// output: register to global mem, // output: register to global mem
const auto c_thread_mtx_begin = const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
...@@ -381,17 +395,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -381,17 +395,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
} }
#endif #endif
threadwise_tensor_slice_copy(out_10d_thread_desc, Float* p_out_thread_on_global = p_out_global +
p_out_thread, out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_10d_global_desc, k_block_data_begin + k_thread_data_begin,
p_out_global + ho_block_data_begin + ho_thread_data_begin,
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( wo_block_data_begin + wo_thread_data_begin,
k_block_data_begin + k_thread_data_begin, n_block_data_begin + n_thread_data_begin);
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
n_block_data_begin + n_thread_data_begin), decltype(out_10d_global_desc),
out_10d_thread_desc.GetLengths(), decltype(out_10d_thread_desc.GetLengths()),
Number<OutThreadCopyDataPerWrite_N>{}); arithmetic_sequence_gen<0, 10, 1>::type,
arithmetic_sequence_gen<0, 10, 1>::type,
9,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
}).Else([&](auto fwd) { }).Else([&](auto fwd) {
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0, GemmNPerThreadSubC % NPerThread == 0,
...@@ -432,17 +453,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer ...@@ -432,17 +453,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
} }
#endif #endif
threadwise_tensor_slice_copy(out_10d_thread_desc, Float* p_out_thread_on_global = p_out_global +
p_out_thread, out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_10d_global_desc, k_block_data_begin + k_thread_data_begin,
p_out_global + ho_block_data_begin + ho_thread_data_begin,
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( wo_block_data_begin + wo_thread_data_begin,
k_block_data_begin + k_thread_data_begin, n_block_data_begin + n_thread_data_begin);
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
n_block_data_begin + n_thread_data_begin), decltype(out_10d_global_desc),
out_10d_thread_desc.GetLengths(), decltype(out_10d_thread_desc.GetLengths()),
Number<OutThreadCopyDataPerWrite_N>{}); arithmetic_sequence_gen<0, 10, 1>::type,
arithmetic_sequence_gen<0, 10, 1>::type,
9,
9,
OutThreadCopyDataPerAccess_N,
OutThreadCopyDataPerAccess_N>(
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
.Run(p_out_thread, p_out_thread_on_global);
}); });
} }
}; };
......
...@@ -254,19 +254,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer ...@@ -254,19 +254,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
Float p_in_register_clipboard[blockwise_in_copy_reorder Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset, p_in_register_buffer);
p_in_register_clipboard); blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_buffer);
p_wei_register_clipboard);
blockwise_in_copy_reorder.RunStoreRegisterBuffer(p_in_register_buffer,
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
p_in_block_double); blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_double);
p_wei_block_double);
} }
// LDS double buffer: main body // LDS double buffer: main body
...@@ -288,10 +287,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer ...@@ -288,10 +287,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float Float
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_in_global_block_offset += p_in_global_block_offset +=
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
...@@ -301,27 +299,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer ...@@ -301,27 +299,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_clipboard); p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread); run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy_reorder.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_next); p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_next); p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_clipboard[blockwise_in_copy_reorder Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
// even iteration // even iteration
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
...@@ -330,19 +327,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer ...@@ -330,19 +327,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_clipboard); p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread); run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard( blockwise_in_copy_reorder.RunStoreRegisterBuffer(
p_in_register_clipboard, p_in_block_double + in_block_space); p_in_register_buffer, p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard( blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_register_clipboard, p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
......
...@@ -214,16 +214,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn ...@@ -214,16 +214,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
__syncthreads()) __syncthreads())
{ {
// load data // load data
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
p_in_register_clipboard); blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_buffer);
p_wei_register_clipboard);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block); blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block); blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block);
__syncthreads(); __syncthreads();
......
...@@ -209,17 +209,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -209,17 +209,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
// preload data into LDS // preload data into LDS
{ {
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
p_in_register_clipboard); blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_buffer);
p_wei_register_clipboard);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double); blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_double);
p_wei_block_double);
} }
// register // register
...@@ -247,18 +245,18 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -247,18 +245,18 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
// load next data // load next data
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
__syncthreads(); __syncthreads();
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_clipboard); p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_buffer);
// compute on current data // compute on current data
// a series of GEMM // a series of GEMM
...@@ -280,10 +278,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -280,10 +278,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
} }
} }
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
p_in_block_next); blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
} }
} }
...@@ -295,14 +291,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -295,14 +291,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
__syncthreads(); __syncthreads();
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_buffer);
for(index_t y = 0; y < Y; ++y) for(index_t y = 0; y < Y; ++y)
{ {
...@@ -322,10 +317,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -322,10 +317,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
} }
} }
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd // odd
__syncthreads(); __syncthreads();
......
...@@ -267,9 +267,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer ...@@ -267,9 +267,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
...@@ -277,26 +276,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer ...@@ -277,26 +276,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global, blockwise_in_copy.RunLoadRegisterBuffer(p_in_block_on_global,
p_in_register_clipboard); p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_clipboard); p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_next); p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_next); p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration // even iteration
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
...@@ -305,19 +304,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer ...@@ -305,19 +304,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global, blockwise_in_copy.RunLoadRegisterBuffer(p_in_block_on_global,
p_in_register_clipboard); p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_clipboard); p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard( blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_register_clipboard, p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
......
...@@ -176,22 +176,21 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -176,22 +176,21 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
InBlockCopyDstDataPerWrite_N2>( InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
#else #else
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_in_copy =
BlockSize, BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc), decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc), decltype(in_e_n1_b_n2_block_desc),
MergedTensorCoordinate<decltype(in_e_n1_b_n2_global_merged_desc)>, decltype(in_e_n1_b_n2_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(in_e_n1_b_n2_block_desc)>, InBlockCopySubLengths_E_N1_B_N2,
decltype(in_e_n1_b_n2_block_desc.GetLengths()), InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopySubLengths_E_N1_B_N2, InBlockCopyThreadClusterArrangeOrder,
InBlockCopyClusterLengths_E_N1_B_N2, InBlockCopySrcAccessOrder,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyDstAccessOrder,
InBlockCopySrcAccessOrder, 2,
InBlockCopyDstAccessOrder, 3,
2, InBlockCopySrcDataPerRead_B,
3, InBlockCopyDstDataPerWrite_N2>(
InBlockCopySrcDataPerRead_B, {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
#endif #endif
// weight tensor // weight tensor
...@@ -225,22 +224,21 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -225,22 +224,21 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
WeiBlockCopyDstDataPerWrite_K>( WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0}); {0, k_block_data_on_global}, {0, 0});
#else #else
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_wei_copy =
BlockSize, BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>, decltype(wei_e_k_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>, WeiBlockCopySubLengths_E_K,
decltype(wei_e_k_block_desc.GetLengths()), WeiBlockCopyClusterLengths_E_K,
WeiBlockCopySubLengths_E_K, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcAccessOrder,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcAccessOrder, 0,
WeiBlockCopyDstAccessOrder, 1,
0, WeiBlockCopySrcDataPerRead_E,
1, WeiBlockCopyDstDataPerWrite_K>(
WeiBlockCopySrcDataPerRead_E, {0, k_block_data_on_global}, {0, 0});
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
#endif #endif
// GEMM definition // GEMM definition
...@@ -448,8 +446,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -448,8 +446,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
ThreadwiseGenericTensorSliceCopy_v2r1< ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
NormalTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc)>,
MergedTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc)>,
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 8, 1>::type, arithmetic_sequence_gen<0, 8, 1>::type,
arithmetic_sequence_gen<0, 8, 1>::type, arithmetic_sequence_gen<0, 8, 1>::type,
......
...@@ -313,8 +313,8 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer ...@@ -313,8 +313,8 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
...@@ -322,25 +322,23 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer ...@@ -322,25 +322,23 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_clipboard); p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
p_in_block_next); blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
...@@ -349,18 +347,17 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer ...@@ -349,18 +347,17 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
......
...@@ -319,8 +319,8 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer ...@@ -319,8 +319,8 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
...@@ -328,9 +328,9 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer ...@@ -328,9 +328,9 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_clipboard); p_wei_register_buffer);
#if 0 #if 0
if(get_block_1d_id() == 0) if(get_block_1d_id() == 0)
...@@ -338,10 +338,10 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer ...@@ -338,10 +338,10 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
printf("tid (%d %d), %f %f %f %f\n", printf("tid (%d %d), %f %f %f %f\n",
get_block_1d_id(), get_block_1d_id(),
get_thread_local_1d_id(), get_thread_local_1d_id(),
p_wei_register_clipboard[0], p_wei_register_buffer[0],
p_wei_register_clipboard[1], p_wei_register_buffer[1],
p_wei_register_clipboard[2], p_wei_register_buffer[2],
p_wei_register_clipboard[3]); p_wei_register_buffer[3]);
} }
#endif #endif
...@@ -349,17 +349,15 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer ...@@ -349,17 +349,15 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
p_in_block_next); blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
...@@ -368,18 +366,17 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer ...@@ -368,18 +366,17 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
......
...@@ -134,8 +134,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -134,8 +134,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
BlockwiseGenericTensorSliceCopy_v2<BlockSize, BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()), decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B, InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B, InBlockCopyClusterLengths_E_B,
...@@ -162,22 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -162,22 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_wei_copy =
BlockSize, BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>, decltype(wei_e_k_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>, WeiBlockCopySubLengths_E_K,
decltype(wei_e_k_block_desc.GetLengths()), WeiBlockCopyClusterLengths_E_K,
WeiBlockCopySubLengths_E_K, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcAccessOrder,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcAccessOrder, 0,
WeiBlockCopyDstAccessOrder, 1,
0, WeiBlockCopySrcDataPerRead_E,
1, WeiBlockCopyDstDataPerWrite_K>(
WeiBlockCopySrcDataPerRead_E, {0, k_block_data_on_global}, {0, 0});
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -365,21 +362,20 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -365,21 +362,20 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
using OutThreadCopySliceLengths = using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>; Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1< auto threadwise_out_copy =
decltype(out_k0_k1_b_thread_desc), ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc), decltype(out_k0_k1_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>, OutThreadCopySliceLengths,
MergedTensorCoordinate<decltype(out_k0_k1_b_global_desc)>, arithmetic_sequence_gen<0, 3, 1>::type,
OutThreadCopySliceLengths, arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type, 2,
arithmetic_sequence_gen<0, 3, 1>::type, 2,
2, OutThreadCopyDataPerAccess_B,
2, OutThreadCopyDataPerAccess_B>(
OutThreadCopyDataPerAccess_B, {0, 0, 0},
OutThreadCopyDataPerAccess_B>({0, 0, 0}, {k_thread_data_on_global / K1,
{k_thread_data_on_global / K1, k_thread_data_on_global % K1,
k_thread_data_on_global % K1, b_thread_data_on_global});
b_thread_data_on_global});
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat) for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{ {
......
...@@ -295,5 +295,27 @@ struct MergedTensorCoordinate ...@@ -295,5 +295,27 @@ struct MergedTensorCoordinate
index_t mOffset; index_t mOffset;
}; };
template <class TensorDesc>
struct TensorCoordinate
{
private:
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
{
return NormalTensorCoordinate<ConstantTensorDescriptor<Ts...>>();
}
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
{
return MergedTensorCoordinate<ConstantMergedTensorDescriptor<Ts...>>();
}
public:
using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};
} // namespace ck } // namespace ck
#endif #endif
...@@ -563,7 +563,7 @@ struct Blockwise2dTensorCopy3 ...@@ -563,7 +563,7 @@ struct Blockwise2dTensorCopy3
} }
} }
__device__ constexpr index_t GetRegisterClipboardSize() const __device__ constexpr index_t GetRegisterBufferSize() const
{ {
static_assert(is_same<Float, float>{}, "wrong! only support float!\n"); static_assert(is_same<Float, float>{}, "wrong! only support float!\n");
...@@ -579,8 +579,8 @@ struct Blockwise2dTensorCopy3 ...@@ -579,8 +579,8 @@ struct Blockwise2dTensorCopy3
return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0; return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
} }
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src, __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const Float* __restrict__ p_clipboard) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -630,8 +630,8 @@ struct Blockwise2dTensorCopy3 ...@@ -630,8 +630,8 @@ struct Blockwise2dTensorCopy3
} }
} }
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard, __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const Float* __restrict__ p_dst) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -681,8 +681,8 @@ struct Blockwise2dTensorCopy3 ...@@ -681,8 +681,8 @@ struct Blockwise2dTensorCopy3
} }
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
__device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src, __device__ void RunLoadRegisterBuffer_asm(const Float* __restrict__ p_src,
Float* p_clipboard) const Float* p_clipboard) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -741,8 +741,8 @@ struct Blockwise2dTensorCopy3 ...@@ -741,8 +741,8 @@ struct Blockwise2dTensorCopy3
} }
} }
__device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard, __device__ void RunStoreRegisterBuffer_asm(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const Float* __restrict__ p_dst) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
......
...@@ -237,7 +237,7 @@ struct Blockwise3dTensorCopy3 ...@@ -237,7 +237,7 @@ struct Blockwise3dTensorCopy3
} }
} }
__device__ static constexpr index_t GetRegisterClipboardSize() __device__ static constexpr index_t GetRegisterBufferSize()
{ {
static_assert(is_same<Float, float>{}, "wrong! only support float!\n"); static_assert(is_same<Float, float>{}, "wrong! only support float!\n");
...@@ -260,8 +260,8 @@ struct Blockwise3dTensorCopy3 ...@@ -260,8 +260,8 @@ struct Blockwise3dTensorCopy3
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2; return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
} }
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src, __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const Float* __restrict__ p_clipboard) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -316,8 +316,8 @@ struct Blockwise3dTensorCopy3 ...@@ -316,8 +316,8 @@ struct Blockwise3dTensorCopy3
} }
} }
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard, __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const Float* __restrict__ p_dst) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
......
...@@ -596,7 +596,7 @@ struct Blockwise4dTensorCopy3 ...@@ -596,7 +596,7 @@ struct Blockwise4dTensorCopy3
} }
} }
__device__ constexpr index_t GetRegisterClipboardSize() const __device__ constexpr index_t GetRegisterBufferSize() const
{ {
static_assert(is_same<Float, float>{}, "wrong! only support float!\n"); static_assert(is_same<Float, float>{}, "wrong! only support float!\n");
...@@ -623,8 +623,8 @@ struct Blockwise4dTensorCopy3 ...@@ -623,8 +623,8 @@ struct Blockwise4dTensorCopy3
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3; return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
} }
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src, __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const Float* __restrict__ p_clipboard) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -690,8 +690,8 @@ struct Blockwise4dTensorCopy3 ...@@ -690,8 +690,8 @@ struct Blockwise4dTensorCopy3
} }
} }
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard, __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const Float* __restrict__ p_dst) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
......
...@@ -420,8 +420,6 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -420,8 +420,6 @@ struct BlockwiseGenericTensorSliceCopy_v1
template <index_t BlockSize, template <index_t BlockSize,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
class SrcCoordinate,
class DstCoordinate,
class SliceLengths, class SliceLengths,
class SubLengths, class SubLengths,
class ThreadClusterLengths, class ThreadClusterLengths,
...@@ -436,6 +434,9 @@ struct BlockwiseGenericTensorSliceCopy_v2 ...@@ -436,6 +434,9 @@ struct BlockwiseGenericTensorSliceCopy_v2
{ {
static constexpr index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
using SrcCoordinate = typename TensorCoordinate<SrcDesc>::type;
using DstCoordinate = typename TensorCoordinate<DstDesc>::type;
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin, __device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin,
DstCoordinate dst_block_slice_origin) DstCoordinate dst_block_slice_origin)
{ {
...@@ -515,31 +516,25 @@ struct BlockwiseGenericTensorSliceCopy_v2 ...@@ -515,31 +516,25 @@ struct BlockwiseGenericTensorSliceCopy_v2
private: private:
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{})); using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
using ThreadwiseLoad = using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc,
ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc, RegisterBufferDesc,
RegisterBufferDesc, SubLengths,
SrcCoordinate, SrcDimAccessOrder,
NormalTensorCoordinate<RegisterBufferDesc>, SrcDimAccessOrder,
SubLengths, SrcVectorAccessDim,
SrcDimAccessOrder, SrcVectorAccessDim,
SrcDimAccessOrder, SrcDataPerAccess,
SrcVectorAccessDim, 1>;
SrcVectorAccessDim,
SrcDataPerAccess, using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v2r1<RegisterBufferDesc,
1>; DstDesc,
SubLengths,
using ThreadwiseStore = DstDimAccessOrder,
ThreadwiseGenericTensorSliceCopy_v2r1<RegisterBufferDesc, DstDimAccessOrder,
DstDesc, DstVectorAccessDim,
NormalTensorCoordinate<RegisterBufferDesc>, DstVectorAccessDim,
DstCoordinate, 1,
SubLengths, DstDataPerAccess>;
DstDimAccessOrder,
DstDimAccessOrder,
DstVectorAccessDim,
DstVectorAccessDim,
1,
DstDataPerAccess>;
ThreadwiseLoad mThreadwiseLoad; ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore; ThreadwiseStore mThreadwiseStore;
......
...@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceReorderCopy_v3 ...@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
#endif #endif
} }
__device__ static constexpr index_t GetRegisterClipboardSize() __device__ static constexpr index_t GetRegisterBufferSize()
{ {
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
...@@ -183,8 +183,8 @@ struct BlockwiseTensorSliceReorderCopy_v3 ...@@ -183,8 +183,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
return thread_tensor_desc.GetElementSpace(); return thread_tensor_desc.GetElementSpace();
} }
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src, __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const Float* __restrict__ p_clipboard) const
{ {
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
...@@ -219,8 +219,8 @@ struct BlockwiseTensorSliceReorderCopy_v3 ...@@ -219,8 +219,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
}); });
} }
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard, __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const Float* __restrict__ p_dst) const
{ {
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
...@@ -274,10 +274,10 @@ struct BlockwiseTensorSliceReorderCopy_v3 ...@@ -274,10 +274,10 @@ struct BlockwiseTensorSliceReorderCopy_v3
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{ {
Float p_clipboard[GetRegisterClipboardSize()]; Float p_clipboard[GetRegisterBufferSize()];
RunLoadRegisterClipboard(p_src, p_clipboard); RunLoadRegisterBuffer(p_src, p_clipboard);
RunStoreRegisterClipboard(p_clipboard, p_dst); RunStoreRegisterBuffer(p_clipboard, p_dst);
} }
// this function doesn't do santiy check on whether the slicing window is out of the boundary // this function doesn't do santiy check on whether the slicing window is out of the boundary
......
...@@ -14,10 +14,6 @@ ...@@ -14,10 +14,6 @@
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#endif #endif
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
#endif
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 #ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
#endif #endif
...@@ -430,170 +426,6 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2 ...@@ -430,170 +426,6 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
Array<index_t, nDim> mDstSliceOrigin; Array<index_t, nDim> mDstSliceOrigin;
}; };
template <class SrcDesc,
class DstDesc,
class SrcCoordinate,
class DstCoordinate,
class SliceLengths>
struct ThreadwiseGenericTensorSliceCopy_v2
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2(SrcCoordinate src_slice_origin,
DstCoordinate dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
: ThreadwiseGenericTensorSliceCopy_v2(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
{
}
__device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
{
mSrcSliceOrigin = src_slice_origin;
}
__device__ void SetDstSliceOrigin(DstCoordinate dst_slice_origin)
{
mDstSliceOrigin = dst_slice_origin;
}
template <class TDesc, class Seq>
struct IsolateMergedDimSliceLengthsHack
{
template <class IDim>
__device__ constexpr index_t operator()(IDim idim) const
{
return TDesc::ContainMultipleOriginalDimensions(idim) ? Seq{}[idim] : 1;
}
};
template <class TData>
__device__ void Run(const TData* p_src, TData* p_dst) const
{
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
TData p_buffer_[buffer_desc.GetElementSpace()];
TData* p_buffer = p_buffer_;
// hacks to isolate merged dimension from normal dimensions, and calculate their offset
// seperately
// SrcMergedDimSliceLengthsHack has entry same as SliceLengths on src merged dimensions,
// but 1 on normal dimensions;
// SrcNormalDimSliceLengthsHack has entry same as SliceLengths on src normal dimensions,
// but 1 on merged dimensions;
using SrcMergedDimSliceLengthsHack =
typename sequence_gen<SliceLengths::GetSize(),
IsolateMergedDimSliceLengthsHack<SrcDesc, SliceLengths>>::type;
using SrcNormalDimSliceLengthsHack =
decltype((SliceLengths{} + Number<1>{}) - SrcMergedDimSliceLengthsHack{});
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2
static_ford<SrcMergedDimSliceLengthsHack>{}([&](auto merged_dim_data_id_) {
constexpr auto merged_dim_data_id = decltype(merged_dim_data_id_){};
const TData* p_src_tmp = p_src + (mSrcSliceOrigin + merged_dim_data_id).GetOffset();
static_ford<SrcNormalDimSliceLengthsHack>{}([&](auto normal_dim_data_id_) {
constexpr auto normal_dim_data_id = decltype(normal_dim_data_id_){};
constexpr index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(merged_dim_data_id + normal_dim_data_id);
constexpr index_t src_normal_offset =
SrcDesc::GetOffsetFromMultiIndex(normal_dim_data_id);
p_buffer[buffer_offset] = p_src_tmp[src_normal_offset];
});
});
#else
ford<SrcMergedDimSliceLengthsHack>{}([&](auto merged_dim_data_id) {
const TData* p_src_tmp = p_src + (mSrcSliceOrigin + merged_dim_data_id).GetOffset();
ford<SrcNormalDimSliceLengthsHack>{}([&](auto normal_dim_data_id) {
const index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(merged_dim_data_id + normal_dim_data_id);
const index_t src_normal_offset =
SrcDesc::GetOffsetFromMultiIndex(normal_dim_data_id);
p_buffer[buffer_offset] = p_src_tmp[src_normal_offset];
});
});
#endif
// DstMergedDimSliceLengthsHack has entry same as SliceLengths on dst merged dimensions,
// but 1 on normal dimensions;
// DstNormalDimSliceLengthsHack has entry same as SliceLengths on dst normal dimensions,
// but 1 on merged dimensions;
using DstMergedDimSliceLengthsHack =
typename sequence_gen<SliceLengths::GetSize(),
IsolateMergedDimSliceLengthsHack<DstDesc, SliceLengths>>::type;
using DstNormalDimSliceLengthsHack =
decltype((SliceLengths{} + Number<1>{}) - DstMergedDimSliceLengthsHack{});
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2
static_ford<DstMergedDimSliceLengthsHack>{}([&](auto merged_dim_data_id_) {
constexpr auto merged_dim_data_id = decltype(merged_dim_data_id_){};
TData* p_dst_tmp = p_dst + (mDstSliceOrigin + merged_dim_data_id).GetOffset();
static_ford<DstNormalDimSliceLengthsHack>{}([&](auto normal_dim_data_id_) {
constexpr auto normal_dim_data_id = decltype(normal_dim_data_id_){};
constexpr index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(merged_dim_data_id + normal_dim_data_id);
constexpr index_t dst_normal_offset =
DstDesc::GetOffsetFromMultiIndex(normal_dim_data_id);
p_dst_tmp[dst_normal_offset] = p_buffer[buffer_offset];
});
});
#else
ford<DstMergedDimSliceLengthsHack>{}([&](auto merged_dim_data_id) {
TData* p_dst_tmp = p_dst + (mDstSliceOrigin + merged_dim_data_id).GetOffset();
ford<DstNormalDimSliceLengthsHack>{}([&](auto normal_dim_data_id) {
const index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(merged_dim_data_id + normal_dim_data_id);
const index_t dst_normal_offset =
DstDesc::GetOffsetFromMultiIndex(normal_dim_data_id);
p_dst_tmp[dst_normal_offset] = p_buffer[buffer_offset];
});
});
#endif
}
// T can be Sequence or Array
template <class T, bool PositiveDirection>
__device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
{
static_if<PositiveDirection>{}([&](auto) {
mSrcSliceOrigin += step_sizes;
}).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
}
template <class T, bool PositiveDirection>
__device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
{
static_if<PositiveDirection>{}([&](auto) {
mDstSliceOrigin += step_sizes;
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
}
private:
SrcCoordinate mSrcSliceOrigin;
DstCoordinate mDstSliceOrigin;
};
// This threadwise copy allow vector access of src and dst. // This threadwise copy allow vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst. // It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst. // It also allows the vector size to be different on src and dst.
...@@ -605,8 +437,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2 ...@@ -605,8 +437,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2
// used for the buffer. // used for the buffer.
template <class SrcDesc, template <class SrcDesc,
class DstDesc, class DstDesc,
class SrcCoordinate,
class DstCoordinate,
class SliceLengths, class SliceLengths,
class SrcDimAccessOrder, class SrcDimAccessOrder,
class DstDimAccessOrder, class DstDimAccessOrder,
...@@ -618,6 +448,9 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1 ...@@ -618,6 +448,9 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
{ {
static constexpr index_t nDim = SliceLengths::GetSize(); static constexpr index_t nDim = SliceLengths::GetSize();
using SrcCoordinate = typename TensorCoordinate<SrcDesc>::type;
using DstCoordinate = typename TensorCoordinate<DstDesc>::type;
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1(SrcCoordinate src_slice_origin, __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1(SrcCoordinate src_slice_origin,
DstCoordinate dst_slice_origin) DstCoordinate dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin) : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
......
...@@ -107,11 +107,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -107,11 +107,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>; using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0 #elif 0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1 // for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
...@@ -137,12 +137,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -137,12 +137,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<0, 0, 0, 0>; // not used using InBlockCopyClusterLengths_CHWN = Sequence<0, 0, 0, 0>; // not used
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 1 #elif 1
// for 3x3, 34x34, v1r3, Pascal // for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal // for 3x3, 28x28, v1r3, Pascal
...@@ -170,12 +170,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -170,12 +170,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>; using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4; using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; using WeiBlockCopySubLengths_CK = Sequence<2, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<4, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0 #elif 0
// for 3x3, 34x34, v1r3, Pascal, bad // for 3x3, 34x34, v1r3, Pascal, bad
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
...@@ -201,12 +204,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -201,12 +204,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<2, 2, 32, 1>; using InBlockCopyClusterLengths_CHWN = Sequence<2, 2, 32, 1>;
constexpr index_t InBlockCopyDataPerRead_N = 1; constexpr index_t InBlockCopyDataPerAccess_N = 1;
constexpr index_t WeiBlockCopyDataPerRead_K = 2; constexpr index_t WeiBlockCopyDataPerAccess_K = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 1; constexpr index_t OutThreadCopyDataPerAccess_N = 1;
#elif 0 #elif 0
// for 3x3, 34x34, v1r1, Vega 20 // for 3x3, 34x34, v1r1, Vega 20
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -232,12 +235,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -232,12 +235,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>; using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>;
constexpr index_t InBlockCopyDataPerRead_N = 2; constexpr index_t InBlockCopyDataPerAccess_N = 2;
constexpr index_t WeiBlockCopyDataPerRead_K = 2; constexpr index_t WeiBlockCopyDataPerAccess_K = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 4; constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 1 #elif 1
// for 3x3, 34x34, v1r3, Vega 20 // for 3x3, 34x34, v1r3, Vega 20
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -263,12 +266,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -263,12 +266,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>; using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 4; constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 0 #elif 0
// for 3x3, 56x56, v1r1, Pascal // for 3x3, 56x56, v1r1, Pascal
constexpr index_t NPerBlock = 32; constexpr index_t NPerBlock = 32;
...@@ -282,13 +285,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -282,13 +285,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1; constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1; constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8; constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -298,7 +301,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -298,7 +301,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 0
...@@ -324,14 +327,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -324,14 +327,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 1; constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 1; constexpr index_t GemmDataPerReadB = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 1; constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 2; constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4; constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 4; constexpr index_t OutThreadCopyDataPerAccess_N = 4;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 0
...@@ -347,13 +350,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -347,13 +350,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1; constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1; constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8; constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -365,7 +368,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -365,7 +368,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 0
...@@ -393,12 +396,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -393,12 +396,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>; using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0 #elif 0
// for 1x1, 28x28, v1r1, Pascal // for 1x1, 28x28, v1r1, Pascal
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
...@@ -413,13 +416,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -413,13 +416,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1; constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1; constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8; constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2; constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2; constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4; constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -429,7 +432,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -429,7 +432,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 0
...@@ -453,14 +456,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -453,14 +456,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 2; constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8; constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2; constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2; constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4; constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4; constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#endif #endif
...@@ -478,9 +481,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -478,9 +481,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0 #elif 0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 0 #elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
#endif #endif
<GridSize, <GridSize,
...@@ -507,10 +510,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -507,10 +510,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB, GemmDataPerReadB,
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN, InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N, InBlockCopyDataPerAccess_N,
WeiBlockCopyDataPerRead_K, WeiBlockCopySubLengths_CK,
OutThreadCopyDataPerWrite_N>{}; WeiBlockCopyClusterLengths_CK,
WeiBlockCopyDataPerAccess_K,
OutThreadCopyDataPerAccess_N>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv.hpp" #include "host_conv.hpp"
#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp" #include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp" #include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp" //#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp" //#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" //#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
...@@ -85,7 +85,7 @@ int main(int argc, char* argv[]) ...@@ -85,7 +85,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -367,7 +367,7 @@ int main(int argc, char* argv[]) ...@@ -367,7 +367,7 @@ int main(int argc, char* argv[])
#if 0 #if 0
device_convolution_direct_v2_nchw_kcyx_nkhw device_convolution_direct_v2_nchw_kcyx_nkhw
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 0 #elif 1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn( device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(
in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 0 #elif 0
...@@ -379,7 +379,7 @@ int main(int argc, char* argv[]) ...@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -409,7 +409,7 @@ int main(int argc, char* argv[]) ...@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment