Commit f7be86b9 authored by Chao Liu

refactor

parent d7079939
@@ -158,24 +158,20 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // slice a merged tensor, reorder and copy to a normal tensor
         // this copy operator already has blockwise offset built-in
         auto blockwise_in_copy =
-#if 0
-            BlockwiseGenericTensorSliceCopy_v1
-#else
-            BlockwiseGenericTensorSliceCopy_v2
-#endif
-            <BlockSize,
-             decltype(in_e_n1_b_n2_global_merged_desc),
-             decltype(in_e_n1_b_n2_block_desc),
-             decltype(in_e_n1_b_n2_block_desc.GetLengths()),
-             InBlockCopySubLengths_E_N1_B_N2,
-             InBlockCopyClusterLengths_E_N1_B_N2,
-             InBlockCopyThreadClusterArrangeOrder,
-             InBlockCopySrcAccessOrder,
-             InBlockCopyDstAccessOrder,
-             2,
-             3,
-             InBlockCopySrcDataPerRead_B,
-             InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
+            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+                                               decltype(in_e_n1_b_n2_global_merged_desc),
+                                               decltype(in_e_n1_b_n2_block_desc),
+                                               decltype(in_e_n1_b_n2_block_desc.GetLengths()),
+                                               InBlockCopySubLengths_E_N1_B_N2,
+                                               InBlockCopyClusterLengths_E_N1_B_N2,
+                                               InBlockCopyThreadClusterArrangeOrder,
+                                               InBlockCopySrcAccessOrder,
+                                               InBlockCopyDstAccessOrder,
+                                               2,
+                                               3,
+                                               InBlockCopySrcDataPerRead_B,
+                                               InBlockCopyDstDataPerWrite_N2>(
+                {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
         // weight tensor
         // tensor descriptor in device memory, src of blockwise copy
@@ -192,24 +188,20 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // slice a tensor, and copy it into another tensor
         // this copy operator already has blockwise offset built-in
         auto blockwise_wei_copy =
-#if 0
-            BlockwiseGenericTensorSliceCopy_v1
-#else
-            BlockwiseGenericTensorSliceCopy_v2
-#endif
-            <BlockSize,
-             decltype(wei_e_k_global_desc),
-             decltype(wei_e_k_block_desc),
-             decltype(wei_e_k_block_desc.GetLengths()),
-             WeiBlockCopySubLengths_E_K,
-             WeiBlockCopyClusterLengths_E_K,
-             WeiBlockCopyThreadClusterArrangeOrder,
-             WeiBlockCopySrcAccessOrder,
-             WeiBlockCopyDstAccessOrder,
-             0,
-             1,
-             WeiBlockCopySrcDataPerRead_E,
-             WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
+            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+                                               decltype(wei_e_k_global_desc),
+                                               decltype(wei_e_k_block_desc),
+                                               decltype(wei_e_k_block_desc.GetLengths()),
+                                               WeiBlockCopySubLengths_E_K,
+                                               WeiBlockCopyClusterLengths_E_K,
+                                               WeiBlockCopyThreadClusterArrangeOrder,
+                                               WeiBlockCopySrcAccessOrder,
+                                               WeiBlockCopyDstAccessOrder,
+                                               0,
+                                               1,
+                                               WeiBlockCopySrcDataPerRead_E,
+                                               WeiBlockCopyDstDataPerWrite_K>(
+                {0, k_block_data_on_global}, {0, 0});

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
...
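
Note on the copy parameters used above: the sub-length / cluster-length pair describes how a block's slice is split among the block's threads. The sketch below is a minimal illustration of that decomposition, assuming the usual convention that ClusterLengths counts threads per dimension and SubLengths counts elements per thread; thread_data_origin and the plain row-major decode are illustrative, not code from this repository (the real copy also applies a thread-cluster arrange order).

    // Minimal sketch (not repository code): how ClusterLengths and SubLengths
    // split a block tile among threads. E.g. ClusterLengths {8, 2, 16, 1} of
    // threads times SubLengths {1, 1, 1, 4} of elements covers an {8, 2, 16, 4}
    // block slice, and each thread's sub-tile starts at cluster_coord * sub_length.
    #include <array>
    #include <cstddef>

    using index_t = std::size_t;

    std::array<index_t, 4> thread_data_origin(index_t thread_id,
                                              const std::array<index_t, 4>& cluster_lengths,
                                              const std::array<index_t, 4>& sub_lengths)
    {
        std::array<index_t, 4> origin{};
        for(int d = 3; d >= 0; --d)
        {
            const index_t coord = thread_id % cluster_lengths[d]; // this thread's cluster coordinate in dim d
            thread_id /= cluster_lengths[d];
            origin[d] = coord * sub_lengths[d]; // first element of this thread's sub-tile in dim d
        }
        return origin;
    }
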
@@ -51,7 +51,7 @@ template <index_t GridSize,
           index_t WeiBlockCopyDstDataPerWrite_K>
 struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
 {
-#if 1
+#if 0
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
@@ -437,6 +437,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
             "wrong! alignment requirement for vectorized global load of input tensor will "
             "be violated");

+        // input
         constexpr auto in_n_c_hi_wi_global_desc =
             make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
@@ -465,6 +466,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
                 make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

+        // weight
+        constexpr auto wei_e_k_global_desc =
+            transform_tensor_descriptor(wei_k_c_y_x_global_desc,
+                                        make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
+                                        make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+
 #if 0
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
         {
@@ -487,8 +495,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
             print_array("idx1: ", idx1);
             print_array("idx0: ", idx0);
         }
+#else
+        index_t itmp = get_block_1d_id() + get_thread_local_1d_id();
+        auto wei_coord1 = make_tensor_coordinate_v2(wei_e_k_global_desc, {itmp, itmp + 1});
+        auto step_sizes = make_multi_index(EPerBlock, 0);
+        wei_coord1 += step_sizes;
+        p_out_global[0] = wei_coord1.GetLowerCoordinate().GetIndex()[0];
+        p_out_global[1] = wei_coord1.GetLowerCoordinate().GetIndex()[1];
+        p_out_global[2] = wei_coord1.GetLowerCoordinate().GetIndex()[2];
+        p_out_global[3] = wei_coord1.GetLowerCoordinate().GetIndex()[3];
 #endif
+
+        p_out_global[0] = in_e_n1_b_n2_global_desc.CalculateOffset({0, 0, 10, 0});
     }
 #endif
 };
...
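
For context on the wei_e_k_global_desc added above: Merge<Sequence<C, Y, X>> folds the three weight dimensions into a single gemm dimension E = C * Y * X, so an E index maps back to (c, y, x) by div/mod. A minimal sketch of that mapping; e_to_cyx is illustrative, not a repository function.

    // Minimal sketch (not repository code): decomposing a merged E index back
    // into (c, y, x) for E = C * Y * X, as implied by Merge<Sequence<C, Y, X>>.
    #include <array>

    std::array<int, 3> e_to_cyx(int e, int /*C: unused, kept for symmetry*/, int Y, int X)
    {
        const int x = e % X;
        const int y = (e / X) % Y;
        const int c = e / (X * Y);
        return {c, y, x};
    }
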
@@ -197,7 +197,7 @@ struct Merge
             // do carry check in reversed order, starting from lowest dimension
             // don't check the highest dimension
-            static_for<0, nDimLow, 1>{}([&](auto ireverse) {
+            static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
                 constexpr index_t i = nDimLow - 1 - ireverse;

                 if(carry)
@@ -213,6 +213,12 @@
                     carry = true;
                 }
             });
+
+            // highest dimension, no out-of-bound check
+            if(carry)
+            {
+                ++idx_low_new(0);
+            }
         }
         else if(idx_up_diff[0] < 0)
         {
@@ -220,7 +226,7 @@
             // do borrow check in reversed order, starting from lowest dimension
             // don't check the highest dimension
-            static_for<0, nDimLow, 1>{}([&](auto ireverse) {
+            static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
                 constexpr index_t i = nDimLow - 1 - ireverse;

                 if(borrow)
@@ -236,6 +242,12 @@
                     borrow = true;
                 }
             });
+
+            // highest dimension, no out-of-bound check
+            if(borrow)
+            {
+                --idx_low_new(0);
+            }
         }

         return idx_low_new - idx_low_old;
...
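
The two hunks above stop the carry/borrow loop one dimension early and handle the highest dimension afterwards, without an out-of-bound check. A minimal sketch of the forward (carry) case, assuming a single move never exceeds the lowest dimension's length; move_merged_index is illustrative, not the library's Merge member.

    // Minimal sketch (not repository code): add a step to the lowest dimension
    // of a lower multi-index and propagate carries upward. Bounds are checked
    // for every dimension except the highest; the final carry is applied to
    // dimension 0 without an out-of-bound check, mirroring the hunks above.
    #include <array>
    #include <cstddef>

    // assumes 0 <= step < lengths[NDim - 1], i.e. a single forward move
    template <std::size_t NDim>
    void move_merged_index(std::array<int, NDim>& idx, const std::array<int, NDim>& lengths, int step)
    {
        idx[NDim - 1] += step; // add the whole step to the lowest dimension
        bool carry = false;
        for(std::size_t i = NDim - 1; i > 0; --i) // every dimension except the highest
        {
            if(carry)
            {
                ++idx[i];
            }
            carry = false;
            if(idx[i] >= lengths[i]) // out-of-bound check, as in the static_for loop
            {
                idx[i] -= lengths[i];
                carry = true;
            }
        }
        if(carry)
        {
            ++idx[0]; // highest dimension: apply the carry, no out-of-bound check
        }
    }
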
@@ -70,7 +70,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
     using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;

     using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
-    using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
+    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
     using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

     constexpr index_t InBlockCopySrcDataPerRead_B = 1;
...
@@ -74,7 +74,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
     using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
     using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;

     using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
-    using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
+    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
     using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

     constexpr index_t InBlockCopySrcDataPerRead_B = 1;
...
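
On the access-order change above: reading the bracketed comments, a Sequence such as <0, 2, 1, 3> lists the dimensions of the [E, N1, B, N2] slice in traversal order, so the source is now walked as [E, B, N1, N2] instead of [E, N1, N2, B]. Below is a minimal sketch of that interpretation, assuming order[0] is the outermost loop and order[3] the innermost; visit_in_order is illustrative, not a library call.

    // Minimal sketch (not repository code): traverse a 4-d slice with the loop
    // nest permuted by `order`; lengths and idx are in [E, N1, B, N2] order.
    #include <array>
    #include <functional>

    void visit_in_order(const std::array<int, 4>& lengths,
                        const std::array<int, 4>& order,
                        const std::function<void(const std::array<int, 4>&)>& f)
    {
        std::array<int, 4> idx{};
        for(idx[order[0]] = 0; idx[order[0]] < lengths[order[0]]; ++idx[order[0]])
            for(idx[order[1]] = 0; idx[order[1]] < lengths[order[1]]; ++idx[order[1]])
                for(idx[order[2]] = 0; idx[order[2]] < lengths[order[2]]; ++idx[order[2]])
                    for(idx[order[3]] = 0; idx[order[3]] < lengths[order[3]]; ++idx[order[3]])
                        f(idx); // order[3] (here N2 with <0, 2, 1, 3>) varies fastest
    }
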
@@ -74,12 +74,12 @@ int main(int argc, char* argv[])
 {
     using namespace ck;

-#if 1
-    constexpr index_t N = 512;
-    constexpr index_t C = 16;
+#if 0
+    constexpr index_t N = 256;
+    constexpr index_t C = 64;
     constexpr index_t HI = 17;
     constexpr index_t WI = 17;
-    constexpr index_t K = 512;
+    constexpr index_t K = 256;
     constexpr index_t Y = 17;
     constexpr index_t X = 17;
@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 3>;
     using RightPads = Sequence<0, 3>;

-#elif 1
+#elif 0
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;
@@ -117,8 +117,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 8x8 image
     // cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
@@ -133,8 +133,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 7x7 image
     // cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
@@ -149,8 +149,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 8x8 image
     // cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
@@ -165,8 +165,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 14x14 image
     // cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
@@ -181,8 +181,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 8x8 image
     // cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
@@ -197,8 +197,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 28x28 image
     // cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
@@ -213,8 +213,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 7x7 image
     // cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
@@ -229,25 +229,9 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
-    // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
-    // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
-    constexpr index_t N = 128;
-    constexpr index_t C = 288;
-    constexpr index_t HI = 35;
-    constexpr index_t WI = 35;
-    constexpr index_t K = 384;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
-
-    using ConvStrides = Sequence<2, 2>;
-    using ConvDilations = Sequence<1, 1>;
-
-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
-#elif 1
     // 1x1 filter, 17x17 input
     // cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
     constexpr index_t N = 128;
@@ -261,8 +245,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 14x14 image
     // cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
@@ -277,8 +261,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 14x14 image
     // cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
@@ -293,8 +277,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 7x7 image
     // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
@@ -309,8 +293,24 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 0;
-    constexpr index_t WPad = 0;
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 1
+    // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
+    // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
+    constexpr index_t N = 128;
+    constexpr index_t C = 288;
+    constexpr index_t HI = 35;
+    constexpr index_t WI = 35;
+    constexpr index_t K = 384;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
 #endif

     auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
...
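
The driver hunks above replace the symmetric HPad/WPad constants with LeftPads/RightPads sequences, which allow a different pad on each side of a spatial axis. As a sanity check, the standard output-length formula with independent left and right pads reproduces the 35x35 -> 17x17 case re-enabled at the end; conv_output_length is an illustrative helper, not part of the driver.

    // Illustrative helper (not repository code): output length of one spatial
    // axis with independent left/right padding, dilation, and stride.
    constexpr int conv_output_length(int in, int left_pad, int right_pad, int filter, int dilation, int stride)
    {
        return (in + left_pad + right_pad - dilation * (filter - 1) - 1) / stride + 1;
    }

    // The re-enabled case: 35x35 input, 3x3 filter, stride 2, no padding -> 17x17 output.
    static_assert(conv_output_length(35, 0, 0, 3, 1, 2) == 17, "matches the driver comment");
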