Commit d2436a58 authored by Chao Liu

v1r3 nchw*cyxk=nkhw lds double buffer

parent 63cdc6d2
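
This commit switches the nchw*cyxk=nkhw path to an LDS double-buffered gridwise kernel (see the #if selection in the driver diff below). The idea, reduced to a toy per-block reduction with illustrative names only (not the real implicit-GEMM kernel), is to keep two LDS tiles and overlap the global loads for tile i+1 with the math on tile i:

    // Sketch of LDS double buffering; every name here is illustrative.
    #include <hip/hip_runtime.h>

    template <int BlockSize, int KPerBlock>
    __global__ void reduce_with_lds_double_buffer(const float* p_in, float* p_out, int k_total)
    {
        // two LDS tiles: "cur" is consumed while "next" is being filled
        __shared__ float p_lds[2][KPerBlock * BlockSize];

        const int tid        = threadIdx.x;
        const float* p_block = p_in + blockIdx.x * k_total * BlockSize;

        float acc = 0;

        // prologue: load tile 0 into buffer 0
        for(int k = 0; k < KPerBlock; ++k)
            p_lds[0][k * BlockSize + tid] = p_block[k * BlockSize + tid];
        __syncthreads();

        for(int k_begin = 0; k_begin < k_total; k_begin += KPerBlock)
        {
            const int cur  = (k_begin / KPerBlock) % 2;
            const int next = cur ^ 1;

            // issue global loads for the next tile into the other LDS buffer ...
            if(k_begin + KPerBlock < k_total)
            {
                const float* p_src = p_block + (k_begin + KPerBlock) * BlockSize;
                for(int k = 0; k < KPerBlock; ++k)
                    p_lds[next][k * BlockSize + tid] = p_src[k * BlockSize + tid];
            }

            // ... while computing on the current tile (in the real kernel threads
            // read tiles written by other threads, which is what the barrier protects)
            for(int k = 0; k < KPerBlock; ++k)
                acc += p_lds[cur][k * BlockSize + tid];

            // next tile must be fully in LDS before it becomes "cur"
            __syncthreads();
        }

        p_out[blockIdx.x * BlockSize + tid] = acc;
    }

The cost is double the LDS footprint per block, which is one reason the per-block tile sizes tuned in the diffs below matter for occupancy.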
@@ -95,7 +95,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 2;

-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;

     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -130,7 +130,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 2;

-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;

     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -200,7 +200,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 1;

-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;

     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
...
@@ -3,6 +3,7 @@
 #include "device.hpp"
 #include "gridwise_convolution_wrapper.hip.hpp"
 #include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp"

 template <class T, class InDesc, class WeiDesc, class OutDesc>
 void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
@@ -92,7 +93,42 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     constexpr index_t OutThreadCopyDataPerWrite_W = 2;
 #elif 0
-    // for 3x3, 34x34, v1r3, Vega 20
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock = 1;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
+
+    constexpr index_t NPerThread = 1;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 8;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW = Sequence<1, 2, 2, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 1;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
     constexpr index_t BlockSize = 256;
     constexpr index_t NPerBlock = 2;
@@ -125,9 +161,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;

-    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
 #elif 1
-    // for 3x3, 34x34, v1r3, Vega 20, try
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
     constexpr index_t BlockSize = 256;
     constexpr index_t NPerBlock = 4;
@@ -160,7 +196,77 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;

-    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock = 8;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 4;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock = 32;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 4;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
 #elif 1
     // for 3x3, 28x28, v1r3, Pascal
     constexpr index_t BlockSize = 128;
@@ -206,8 +312,13 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
         for(index_t i = 0; i < nrepeat; ++i)
         {
-            constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw<
-                GridSize,
+            constexpr auto gridwise_conv =
+#if 0
+                GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
+#else
+                GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
+#endif
+                <GridSize,
                 BlockSize,
                 T,
                 decltype(in_nchw_desc),
...
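
A pattern worth noting in the new Vega 20 config blocks above: the product of InBlockReorderSrcClusterLengths_NCHW is always 256, i.e. every thread of the block takes part in the input reorder, and the elementwise product of sub-lengths and cluster lengths covers exactly the per-block slice <NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>. A standalone check of that inferred invariant for the WoPerBlock = 4 block (hypothetical helper names, not repository code; compile with -std=c++17):

    #include <cstddef>

    // Sequence_ is a stand-in for the repository's Sequence
    template <std::size_t... Is>
    struct Sequence_
    {
        static constexpr std::size_t product = (Is * ... * 1);
    };

    // WoPerBlock = 4 block: Sub = <4, 1, 1, 1>, Cluster = <2, 8, 4, 4>
    using Sub     = Sequence_<4, 1, 1, 1>;
    using Cluster = Sequence_<2, 8, 4, 4>;

    static_assert(Cluster::product == 256, "reorder cluster must fill BlockSize");
    static_assert(Sub::product * Cluster::product == 8 /*N*/ * 8 /*C*/ * 4 /*Ho*/ * 4 /*Wo*/,
                  "sub-lengths x cluster lengths must cover one block's input slice");

    int main() {} // compiles iff the invariants hold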
@@ -74,22 +74,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");

-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
-
-        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin = n_block_work_id * NPerBlock;
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
@@ -185,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
...
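
Both gridwise kernels here (and the lds_double_buffer variant below) replace the hand-rolled div/mod chain on get_block_1d_id() with a block_work_desc built from Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork> and a single GetMultiIndex call; note the decomposition order also changes from K-major to N-major. A host-side analogue of what GetMultiIndex computes, using a plain array instead of a compile-time descriptor (illustrative code, not from the repository):

    #include <array>
    #include <cstdio>

    // Row-major decomposition of a flat id: the last dimension varies fastest,
    // matching Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>.
    std::array<int, 4> get_multi_index(int id, const std::array<int, 4>& lengths)
    {
        std::array<int, 4> multi_id{};
        for(int i = 3; i >= 0; --i)
        {
            multi_id[i] = id % lengths[i];
            id /= lengths[i];
        }
        return multi_id;
    }

    int main()
    {
        // e.g. NBlockWork = 4, KBlockWork = 1, HBlockWork = 8, WBlockWork = 1
        const auto m = get_multi_index(13, {4, 1, 8, 1});
        std::printf("%d %d %d %d\n", m[0], m[1], m[2], m[3]); // prints: 1 0 5 0
    }

Each block then scales its multi-index by the per-block tile sizes to get the n/k/ho/wo_block_data_begin offsets, exactly as the replaced code did with the individual *_block_work_id variables.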
@@ -81,22 +81,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");

-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
-
-        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin = n_block_work_id * NPerBlock;
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
...
@@ -193,7 +193,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
@@ -340,20 +340,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
         const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

-        static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-            [&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
+                                                                         // perfect forwarding.
+                                                                         // Using this trick to
             // make this lambda a generic lambda, so it won't be compiled until
             // instantiated
-                static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-                               NPerBlock % GemmNPerThreadSubC == 0),
-                              "wrong!");
+            static_assert(
+                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                "wrong!");

             // output is a 10d tensor
             constexpr index_t N2 = GemmNPerThreadSubC;
             constexpr index_t N1 = NPerBlock / N2;
-            constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-                                   f_dummy(NPerBlock / GemmNPerThreadSubC);
+            constexpr index_t W2 =
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
             constexpr index_t W1 = WoPerBlock / W2;
             constexpr index_t K2 = GemmMPerThreadSubC;
@@ -394,16 +395,14 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
                 p_out_thread,
                 out_10d_global_desc,
                 p_out_global +
-                    out_n_k_h_w_global_desc.Get1dIndex(
-                        n_block_data_begin + n_thread_data_begin,
+                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                         k_block_data_begin + k_thread_data_begin,
                         ho_block_data_begin + ho_thread_data_begin,
                         wo_block_data_begin + wo_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 map_out_global2thread);
             // Number<OutThreadCopyDataPerWrite_W>{});
-        })
-            .else_([&](auto f_dummy) {
+        }).else_([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
@@ -418,17 +417,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;

-            constexpr auto out_10d_global_desc =
-                make_ConstantTensorDescriptor(Sequence<N / N1,
-                                                       N1,
-                                                       K / (K1 * K2),
-                                                       K1,
-                                                       K2,
-                                                       Ho,
-                                                       Wo / (W1 * W2 * W3),
-                                                       W1,
-                                                       W2,
-                                                       W3>{});
+            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>{});

             constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                 Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
@@ -454,8 +444,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
                 p_out_thread,
                 out_10d_global_desc,
                 p_out_global +
-                    out_n_k_h_w_global_desc.Get1dIndex(
-                        n_block_data_begin + n_thread_data_begin,
+                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                         k_block_data_begin + k_thread_data_begin,
                         ho_block_data_begin + ho_thread_data_begin,
                         wo_block_data_begin + wo_thread_data_begin),
...
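
The f_dummy comment preserved in the hunks above describes a pre-C++17 idiom: passing a dummy identity functor into a generic lambda makes each branch of static_if a template, so the not-taken branch (including its static_asserts) is never instantiated. A reduced standalone version of the idiom (this static_if is simplified; the repository's utility may differ in details):

    #include <cstdio>

    // identity functor: f_dummy(x) just returns x, but makes expressions dependent
    struct identity
    {
        template <class T>
        constexpr T operator()(T x) const { return x; }
    };

    template <bool>
    struct static_if;

    template <>
    struct static_if<true>
    {
        template <class F>
        const static_if& operator()(F f) const
        {
            f(identity{}); // taken branch: instantiated and run
            return *this;
        }
        template <class F>
        void else_(F) const {} // not-taken branch: never instantiated
    };

    template <>
    struct static_if<false>
    {
        template <class F>
        const static_if& operator()(F) const { return *this; }
        template <class F>
        void else_(F f) const { f(identity{}); }
    };

    template <int GemmNPerThreadSubC, int NPerBlock>
    void choose_path()
    {
        static_if<(GemmNPerThreadSubC <= NPerBlock)>{}([&](auto f_dummy) {
            // generic lambda: the body is only type-checked on instantiation, so
            // this static_assert cannot fire in the branch that is compiled away
            static_assert(f_dummy(GemmNPerThreadSubC) <= NPerBlock, "wrong!");
            std::printf("NPerBlock-major output path\n");
        }).else_([&](auto f_dummy) {
            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock, "wrong!");
            std::printf("GemmNPerThreadSubC-major output path\n");
        });
    }

    int main() { choose_path<4, 8>(); } // takes the first branch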