"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2"
Commit d2436a58 authored by Chao Liu

v1r3 nchw*cyxk=nkhw lds double buffer

parent 63cdc6d2
@@ -95,7 +95,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 2;
-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -130,7 +130,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 2;
-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -200,7 +200,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 1;
-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
...
@@ -3,6 +3,7 @@
 #include "device.hpp"
 #include "gridwise_convolution_wrapper.hip.hpp"
 #include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp"

 template <class T, class InDesc, class WeiDesc, class OutDesc>
 void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
@@ -92,7 +93,42 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     constexpr index_t OutThreadCopyDataPerWrite_W = 2;
 #elif 0
-    // for 3x3, 34x34, v1r3, Vega 20
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
+    constexpr index_t BlockSize = 256;
+    constexpr index_t NPerBlock = 1;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t NPerThread = 1;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 8;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+    using InBlockReorderSrcSubLengths_NCHW = Sequence<1, 2, 2, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 1;
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
     constexpr index_t BlockSize = 256;
     constexpr index_t NPerBlock = 2;
@@ -125,9 +161,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
-    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
 #elif 1
-    // for 3x3, 34x34, v1r3, Vega 20, try
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
     constexpr index_t BlockSize = 256;
     constexpr index_t NPerBlock = 4;
@@ -160,7 +196,77 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
-    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4
+    constexpr index_t BlockSize = 256;
+    constexpr index_t NPerBlock = 8;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+    using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 4;
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2
+    constexpr index_t BlockSize = 256;
+    constexpr index_t NPerBlock = 32;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+    using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>;
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 4;
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
 #elif 1
     // for 3x3, 28x28, v1r3, Pascal
     constexpr index_t BlockSize = 128;
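The newly added Vega 20 configs above all follow one pattern: judging by the parameter names, each thread copies an InBlockReorderSrcSubLengths_NCHW sub-tensor, the thread cluster given by InBlockReorderSrcClusterLengths_NCHW tiles the per-block {NPerBlock, CPerBlock, HoPerBlock, WoPerBlock} slice, and that cluster should contain exactly BlockSize = 256 threads. This reading is an inference from the names, not something stated in the diff; the standalone host-side sketch below only checks that arithmetic for the three added configs at compile time (Seq is a stand-in written for this sketch, not the repo's Sequence).

// Hypothetical compile-time check of the tile arithmetic behind the new configs.
// Assumption: per-thread sub-lengths * thread-cluster lengths == per-block slice,
// and the cluster holds exactly BlockSize threads. Not code from the repository.
#include <cstddef>

template <int... Is>
struct Seq // minimal stand-in for the repo's Sequence<>
{
    static constexpr int At(std::size_t i)
    {
        constexpr int vals[] = {Is...};
        return vals[i];
    }
    static constexpr int Product() { return (Is * ...); }
};

template <class Sub, class Cluster>
constexpr bool covers_block_slice(int n, int c, int h, int w)
{
    return Sub::At(0) * Cluster::At(0) == n && Sub::At(1) * Cluster::At(1) == c &&
           Sub::At(2) * Cluster::At(2) == h && Sub::At(3) * Cluster::At(3) == w;
}

// WoPerBlock = 32 config: NPerBlock=1, CPerBlock=8, HoPerBlock=4, WoPerBlock=32
static_assert(covers_block_slice<Seq<1, 2, 2, 1>, Seq<1, 4, 2, 32>>(1, 8, 4, 32), "tile mismatch");
static_assert(Seq<1, 4, 2, 32>::Product() == 256, "cluster size != BlockSize");

// WoPerBlock = 4 config: NPerBlock=8, CPerBlock=8, HoPerBlock=4, WoPerBlock=4
static_assert(covers_block_slice<Seq<4, 1, 1, 1>, Seq<2, 8, 4, 4>>(8, 8, 4, 4), "tile mismatch");
static_assert(Seq<2, 8, 4, 4>::Product() == 256, "cluster size != BlockSize");

// WoPerBlock = 2 config: NPerBlock=32, CPerBlock=8, HoPerBlock=2, WoPerBlock=2
static_assert(covers_block_slice<Seq<4, 1, 1, 1>, Seq<8, 8, 2, 2>>(32, 8, 2, 2), "tile mismatch");
static_assert(Seq<8, 8, 2, 2>::Product() == 256, "cluster size != BlockSize");

int main() { return 0; }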
@@ -206,39 +312,44 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw<
-            GridSize,
-            BlockSize,
-            T,
-            decltype(in_nchw_desc),
-            decltype(wei_cyxk_desc),
-            decltype(out_nkhw_desc),
-            NPerBlock,
-            KPerBlock,
-            CPerBlock,
-            HoPerBlock,
-            WoPerBlock,
-            NPerThread,
-            KPerThread,
-            HoPerThread,
-            WoPerThread,
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
-            GemmMLevel0Cluster,
-            GemmNLevel0Cluster,
-            GemmMLevel1Cluster,
-            GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
-            GemmDataPerReadA,
-            GemmDataPerReadB,
-            InBlockReorderSrcSubLengths_NCHW,
-            InBlockReorderSrcClusterLengths_NCHW,
-            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
-            InBlockReorderDataPerRead_W,
-            InBlockReorderDataPerWrite_N,
-            WeiBlockCopyClusterLengths,
-            WeiBlockCopyDataPerRead_K,
-            OutThreadCopyDataPerWrite_W>{};
+        constexpr auto gridwise_conv =
+#if 0
+            GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
+#else
+            GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
+#endif
+            <GridSize,
+             BlockSize,
+             T,
+             decltype(in_nchw_desc),
+             decltype(wei_cyxk_desc),
+             decltype(out_nkhw_desc),
+             NPerBlock,
+             KPerBlock,
+             CPerBlock,
+             HoPerBlock,
+             WoPerBlock,
+             NPerThread,
+             KPerThread,
+             HoPerThread,
+             WoPerThread,
+             GemmMPerThreadSubC,
+             GemmNPerThreadSubC,
+             GemmMLevel0Cluster,
+             GemmNLevel0Cluster,
+             GemmMLevel1Cluster,
+             GemmNLevel1Cluster,
+             GemmKPerThreadLoop,
+             GemmDataPerReadA,
+             GemmDataPerReadB,
+             InBlockReorderSrcSubLengths_NCHW,
+             InBlockReorderSrcClusterLengths_NCHW,
+             InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
+             InBlockReorderDataPerRead_W,
+             InBlockReorderDataPerWrite_N,
+             WeiBlockCopyClusterLengths,
+             WeiBlockCopyDataPerRead_K,
+             OutThreadCopyDataPerWrite_W>{};

         float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
...
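The run loop now instantiates the new GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw kernel by default, keeping the single-buffer v1r3 kernel behind the #if 0 branch. The double-buffer kernel body itself is not part of this diff, so the sketch below only illustrates the generic LDS double-buffering pattern the name points at: two LDS tiles, so the copy of iteration k+1 can be issued while iteration k's math runs on the other tile. The helper callables (blockwise_copy, blockwise_gemm) and the index_t alias are placeholders written for this sketch, not the repo's real blockwise operations.

// Simplified, hypothetical LDS double-buffer loop; NOT the repo's kernel.
using index_t = int; // stand-in for the repo's index type

template <index_t TileSize, class BlockwiseCopy, class BlockwiseGemm>
__device__ void run_with_lds_double_buffer(const float* p_in_global,
                                           float* p_out_thread,
                                           index_t num_k_tiles,
                                           BlockwiseCopy blockwise_copy,
                                           BlockwiseGemm blockwise_gemm)
{
    // two buffers: while one is being consumed by the GEMM, the other is filled
    __shared__ float p_in_lds[2][TileSize];

    // preload tile 0 so the first iteration has data to work on
    blockwise_copy(p_in_global, p_in_lds[0]);

    for(index_t k = 0; k < num_k_tiles; ++k)
    {
        // make sure the tile we are about to read is fully written, and the tile
        // we are about to overwrite is no longer being read
        __syncthreads();

        if(k + 1 < num_k_tiles)
        {
            // issue the copy of the next tile into the "other" buffer
            blockwise_copy(p_in_global + (k + 1) * TileSize, p_in_lds[(k + 1) % 2]);
        }

        // accumulate on the tile that is already resident in LDS
        blockwise_gemm(p_in_lds[k % 2], p_out_thread);
    }
}

The single __syncthreads per iteration is what keeps the two buffers coherent: the barrier at the top of iteration k guarantees that the copy issued in iteration k-1 has finished before the GEMM reads that tile, and that the GEMM of iteration k-1 has finished before its buffer is overwritten. The price is twice the LDS footprint per block.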
@@ -371,7 +371,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
                     std::size_t ho = HoPerTile * htile + j;
                     for(int i = 0; i < WoPerTile; ++i)
                     {
                         std::size_t wo = WoPerTile * wtile + i;
                         out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
                     }
                 }
@@ -413,13 +413,13 @@ int main(int argc, char* argv[])
 {
 #if 1
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;
     constexpr index_t HI = 34;
     constexpr index_t WI = 34;
     constexpr index_t K = 128;
     constexpr index_t Y = 3;
     constexpr index_t X = 3;
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
...
@@ -74,22 +74,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                       Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");

-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
-
-        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin = n_block_work_id * NPerBlock;
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
@@ -185,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
...
@@ -81,22 +81,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
                       Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");

-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
-
-        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin = n_block_work_id * NPerBlock;
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
...
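Both gridwise kernels above replace the hand-written div/mod chain over (K, H, W, N) block-work ids with a block_work_desc built from Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork> plus GetMultiIndex. Assuming GetMultiIndex unpacks in the standard row-major order (last dimension fastest), this also changes how block ids map to tiles: N becomes the slowest-varying dimension instead of the fastest. The host-side sketch below reproduces what that decomposition computes for the 3x3, 34x34 case with the WoPerBlock = 8 config selected earlier; integer_divide_ceil and get_multi_index here are simplified re-implementations written for illustration, not the repo's mod_conv utilities.

// Host-side illustration of the new block-work decomposition; assumes row-major
// (last dimension fastest) ordering, matching what GetMultiIndex appears to do.
#include <array>
#include <cstdio>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

// decompose a 1D block id into an (n, k, ho, wo) work id
std::array<int, 4> get_multi_index(int id, std::array<int, 4> lengths)
{
    std::array<int, 4> multi_id{};
    for(int i = 3; i >= 0; --i)
    {
        multi_id[i] = id % lengths[i];
        id /= lengths[i];
    }
    return multi_id;
}

int main()
{
    // 3x3 filter on 34x34 input -> 32x32 output, WoPerBlock = 8 config above
    constexpr int N = 64, K = 128, Ho = 32, Wo = 32;
    constexpr int NPerBlock = 4, KPerBlock = 128, HoPerBlock = 4, WoPerBlock = 8;

    constexpr std::array<int, 4> block_work{integer_divide_ceil(N, NPerBlock),    // 16
                                            integer_divide_ceil(K, KPerBlock),    // 1
                                            integer_divide_ceil(Ho, HoPerBlock),  // 8
                                            integer_divide_ceil(Wo, WoPerBlock)}; // 4

    // e.g. block 37 works on the tile starting at n=4, k=0, ho=4, wo=8
    const auto id = get_multi_index(37, block_work);
    std::printf("n=%d k=%d ho=%d wo=%d\n",
                id[0] * NPerBlock, id[1] * KPerBlock, id[2] * HoPerBlock, id[3] * WoPerBlock);
    return 0;
}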
@@ -193,7 +193,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
@@ -340,39 +340,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
         const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

-        static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-            [&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
-                                // make this lambda a generic lambda, so it won't be compiled until
-                                // instantiated
-                static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-                               NPerBlock % GemmNPerThreadSubC == 0),
-                              "wrong!");
-
-                // output is a 10d tensor
-                constexpr index_t N2 = GemmNPerThreadSubC;
-                constexpr index_t N1 = NPerBlock / N2;
-
-                constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-                                       f_dummy(NPerBlock / GemmNPerThreadSubC);
-                constexpr index_t W1 = WoPerBlock / W2;
-
-                constexpr index_t K2 = GemmMPerThreadSubC;
-                constexpr index_t K1 = KPerBlock / KPerThread;
-
-                constexpr auto out_10d_global_desc =
-                    make_ConstantTensorDescriptor(Sequence<N / f_dummy(N1 * N2),
-                                                           N1,
-                                                           N2,
-                                                           K / (K1 * K2),
-                                                           K1,
-                                                           K2,
-                                                           Ho,
-                                                           Wo / (W1 * W2),
-                                                           W1,
-                                                           W2>{});
-
-                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
+                                                                         // perfect forwarding.
+                                                                         // Using this trick to
+            // make this lambda a generic lambda, so it won't be compiled until
+            // instantiated
+            static_assert(
+                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                "wrong!");
+
+            // output is a 10d tensor
+            constexpr index_t N2 = GemmNPerThreadSubC;
+            constexpr index_t N1 = NPerBlock / N2;
+
+            constexpr index_t W2 =
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
+            constexpr index_t W1 = WoPerBlock / W2;
+
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
+
+            constexpr auto out_10d_global_desc =
+                make_ConstantTensorDescriptor(Sequence<N / f_dummy(N1 * N2),
+                                                       N1,
+                                                       N2,
+                                                       K / (K1 * K2),
+                                                       K1,
+                                                       K2,
+                                                       Ho,
+                                                       Wo / (W1 * W2),
+                                                       W1,
+                                                       W2>{});
+
+            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

 #if 0
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -387,51 +388,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         }
 #endif

-                constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
-
-                threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
-                    out_10d_thread_desc,
-                    p_out_thread,
-                    out_10d_global_desc,
-                    p_out_global +
-                        out_n_k_h_w_global_desc.Get1dIndex(
-                            n_block_data_begin + n_thread_data_begin,
-                            k_block_data_begin + k_thread_data_begin,
-                            ho_block_data_begin + ho_thread_data_begin,
-                            wo_block_data_begin + wo_thread_data_begin),
-                    out_10d_thread_desc.GetLengths(),
-                    map_out_global2thread);
-                // Number<OutThreadCopyDataPerWrite_W>{});
-            })
-            .else_([&](auto f_dummy) {
-                static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
-                                  GemmNPerThreadSubC % NPerThread == 0,
-                              "wrong!");
-
-                // output is a 10d tensor
-                constexpr index_t N1 = NPerBlock;
-
-                constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
-                constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
-                constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
-
-                constexpr index_t K2 = GemmMPerThreadSubC;
-                constexpr index_t K1 = KPerBlock / KPerThread;
-
-                constexpr auto out_10d_global_desc =
-                    make_ConstantTensorDescriptor(Sequence<N / N1,
-                                                           N1,
-                                                           K / (K1 * K2),
-                                                           K1,
-                                                           K2,
-                                                           Ho,
-                                                           Wo / (W1 * W2 * W3),
-                                                           W1,
-                                                           W2,
-                                                           W3>{});
-
-                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+            constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
+
+            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
+                out_10d_thread_desc,
+                p_out_thread,
+                out_10d_global_desc,
+                p_out_global +
+                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+                                                       k_block_data_begin + k_thread_data_begin,
+                                                       ho_block_data_begin + ho_thread_data_begin,
+                                                       wo_block_data_begin + wo_thread_data_begin),
+                out_10d_thread_desc.GetLengths(),
+                map_out_global2thread);
+            // Number<OutThreadCopyDataPerWrite_W>{});
+        }).else_([&](auto f_dummy) {
+            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+                              GemmNPerThreadSubC % NPerThread == 0,
+                          "wrong!");
+
+            // output is a 10d tensor
+            constexpr index_t N1 = NPerBlock;
+
+            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
+            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
+            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
+
+            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>{});
+
+            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

 #if 0
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -447,21 +437,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         }
 #endif

-                constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
-
-                threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
-                    out_10d_thread_desc,
-                    p_out_thread,
-                    out_10d_global_desc,
-                    p_out_global +
-                        out_n_k_h_w_global_desc.Get1dIndex(
-                            n_block_data_begin + n_thread_data_begin,
-                            k_block_data_begin + k_thread_data_begin,
-                            ho_block_data_begin + ho_thread_data_begin,
-                            wo_block_data_begin + wo_thread_data_begin),
-                    out_10d_thread_desc.GetLengths(),
-                    map_out_global2thread);
-                // Number<OutThreadCopyDataPerWrite_W>{});
-            });
+            constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
+
+            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
+                out_10d_thread_desc,
+                p_out_thread,
+                out_10d_global_desc,
+                p_out_global +
+                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+                                                       k_block_data_begin + k_thread_data_begin,
+                                                       ho_block_data_begin + ho_thread_data_begin,
+                                                       wo_block_data_begin + wo_thread_data_begin),
+                out_10d_thread_desc.GetLengths(),
+                map_out_global2thread);
+            // Number<OutThreadCopyDataPerWrite_W>{});
+        });
     }
 };
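The reflowed output-copy code above keeps the static_if<...>{}(...).else_(...) idiom with the f_dummy argument. The point of f_dummy is that it turns each branch into a generic lambda whose body stays dependent on a template parameter, so only the branch that is actually selected gets instantiated, and the static_asserts and descriptor math of the other branch are never compiled. The standalone sketch below reproduces that idiom with a simplified static_if written for illustration (using if constexpr internally); the repo's own static_if is not shown in this diff and may be implemented differently.

// Minimal illustration of the static_if + generic-lambda deferral trick; the
// static_if below is a stand-in for this sketch, not the repository's version.
#include <cstdio>

struct identity // "do nothing but perfect forwarding", as the comment above puts it
{
    template <class T>
    constexpr T operator()(T x) const
    {
        return x;
    }
};

template <bool Cond>
struct static_if
{
    template <class F>
    const static_if& operator()(F f) const
    {
        if constexpr(Cond)
            f(identity{}); // only instantiated when the condition holds
        return *this;
    }

    template <class F>
    void else_(F f) const
    {
        if constexpr(!Cond)
            f(identity{});
    }
};

template <int NPerBlock, int GemmNPerThreadSubC>
void choose_output_layout()
{
    static_if<(GemmNPerThreadSubC <= NPerBlock)>{}([&](auto f_dummy) {
        // wrapping the value in f_dummy keeps the expression dependent, so this
        // static_assert is only checked for the branch that is actually taken
        static_assert(f_dummy(GemmNPerThreadSubC) <= NPerBlock, "wrong!");
        std::printf("narrow-N output layout\n");
    }).else_([&](auto f_dummy) {
        static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock, "wrong!");
        std::printf("wide-N output layout\n");
    });
}

int main()
{
    choose_output_layout<4, 4>(); // takes the first branch
    choose_output_layout<1, 4>(); // takes the else_ branch
    return 0;
}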