Commit 98716c83 authored by Chao Liu

added bwd data v3r1

parent 9750de73
@@ -46,6 +46,34 @@ template <index_t GridSize,
           index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
 struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
 {
+    // this is a hack; we should query this info from gridwise_gemm instead of duplicating its logic
+    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    {
+        constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM,
+                                                    GemmBBlockCopyDstDataPerWrite_GemmN,
+                                                    GemmThreadGemmDataPerReadM,
+                                                    GemmThreadGemmDataPerReadN);
+
+        // A matrix in LDS memory, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence<GemmKPerBlock, GemmMPerBlock>{}, Number<max_lds_align>{});
+
+        // B matrix in LDS memory, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence<GemmKPerBlock, GemmNPerBlock>{}, Number<max_lds_align>{});
+
+        // LDS allocation for A and B: be careful of alignment
+        constexpr index_t a_block_space =
+            math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align);
+
+        constexpr index_t b_block_space =
+            math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align);
+
+        return 2 * (a_block_space + b_block_space) * sizeof(Float);
+    }
+
     __device__ void Run(Float* __restrict__ p_in_global,
                         const Float* __restrict__ p_wei_global,
                         const Float* __restrict__ p_out_global) const
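
Reviewer note: the new GetSharedMemoryNumberOfByte above intentionally mirrors the gridwise GEMM's LDS bookkeeping (flagged as a hack in its own comment). A minimal standalone sketch of the same arithmetic, with illustrative values in place of the kernel's template parameters and math:: helpers:

    // Hedged sketch: round each block tile up to the alignment (assumed here to be
    // the lcm of the vector-write widths), then double it for double buffering.
    #include <cstddef>

    constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
    {
        return ((x + align - 1) / align) * align; // round x up to a multiple of align
    }

    constexpr std::size_t lds_bytes(std::size_t a_space, std::size_t b_space,
                                    std::size_t align, std::size_t bytes_per_elem)
    {
        return 2 * (integer_least_multiple(a_space, align) +
                    integer_least_multiple(b_space, align)) * bytes_per_elem;
    }

    // e.g. GemmKPerBlock = 4, GemmMPerBlock = GemmNPerBlock = 128, 4-byte Float:
    static_assert(lds_bytes(4 * 128, 4 * 128, 4, 4) == 8192, "8 KiB of LDS for this tile");
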
@@ -117,11 +145,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
                 Embed<Y,
                       Sequence<Ydot, Ytilda>,
                       Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
-                      false>{},
+                      true>{},
                 Embed<X,
                       Sequence<Xdot, Xtilda>,
                       Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
-                      false>{}),
+                      true>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
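
Reviewer note: the Embed coefficient sequences here read as a linear index map; assuming Embed<D, Sequence<I0, I1>, Sequence<c0, c1, c2>> computes d = c0*i0 + c1*i1 + c2, the Y filter axis is split so that y = (ConvStrideH / hcf_stride_dilation_h) * ydot + ytilda. A tiny hedged check under assumed sizes (stride 2, dilation 1, so the ratio is 2):

    // Hypothetical values, only to exercise the assumed Embed index map.
    constexpr int conv_stride_h = 2, hcf_stride_dilation_h = 1;

    constexpr int embed_y(int ydot, int ytilda)
    {
        return (conv_stride_h / hcf_stride_dilation_h) * ydot + 1 * ytilda + 0;
    }

    static_assert(embed_y(0, 1) == 1, "ytilda moves y by 1");
    static_assert(embed_y(1, 0) == 2, "ydot moves y by the stride/dilation ratio");
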
@@ -205,155 +233,146 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
             make_tuple(
                 Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
-        // get a series of GEMMs
-        auto f_get_gemm = [&](auto ytilda_, auto xtilda_) {
-            constexpr index_t ytilda = decltype(ytilda_){};
-            constexpr index_t xtilda = decltype(xtilda_){};
-
-            constexpr index_t Ydotnonzero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
-            constexpr index_t Xdotnonzero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
-
-            // A matrix
-            constexpr auto wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc =
-                transform_tensor_descriptor(
-                    wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
-                    make_tuple(PassThrough<K>{},
-                               PassThrough<C>{},
-                               Trim<Sequence<Ydot, Xdot>,
-                                    Sequence<0, 0>,
-                                    Sequence<Ydot - Ydotnonzero, Xdot - Xdotnonzero>>{},
-                               Trim<Sequence<Ytilda, Xtilda>,
-                                    Sequence<ytilda, xtilda>,
-                                    Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
-                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
-                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
-
-            constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
-                wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc,
-                make_tuple(Merge<Sequence<K, Ydotnonzero, Xdotnonzero>>{},
-                           Merge<Sequence<C, 1, 1>>{}),
-                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            // B matrix
-            constexpr auto out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc =
-                transform_tensor_descriptor(
-                    out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
-                    make_tuple(PassThrough<N>{},
-                               PassThrough<K>{},
-                               PassThrough<HtildaTrim>{},
-                               PassThrough<WtildaTrim>{},
-                               Trim<Sequence<Ydot, Xdot>,
-                                    Sequence<0, 0>,
-                                    Sequence<Ydot - Ydotnonzero, Xdot - Xdotnonzero>>{}),
-                    make_tuple(Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<3>{},
-                               Sequence<5>{},
-                               Sequence<2, 4>{}),
-                    make_tuple(Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<3>{},
-                               Sequence<5>{},
-                               Sequence<2, 4>{}));
-
-            constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
-                out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc,
-                make_tuple(Merge<Sequence<K, Ydotnonzero, Xdotnonzero>>{},
-                           Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
-                make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            // C matrix
-            constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
-                transform_tensor_descriptor(
-                    in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
-                    make_tuple(PassThrough<N>{},
-                               PassThrough<C>{},
-                               PassThrough<HtildaTrim>{},
-                               PassThrough<WtildaTrim>{},
-                               Trim<Sequence<Ytilda, Xtilda>,
-                                    Sequence<ytilda, xtilda>,
-                                    Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
-                    make_tuple(Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<3>{},
-                               Sequence<5>{},
-                               Sequence<2, 4>{}),
-                    make_tuple(Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<3>{},
-                               Sequence<5>{},
-                               Sequence<2, 4>{}));
-
-            constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
-                in_n_c_1_htildatrim_1_wtildatrim_global_desc,
-                make_tuple(Merge<Sequence<C, 1, 1>>{},
-                           Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
-                make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
-                GridSize,
-                BlockSize,
-                Float,
-                AccFloat,
-                decltype(wei_gemmk_gemmm_global_desc),
-                decltype(out_gemmk_gemmn_global_desc),
-                decltype(in_gemmm_gemmn_global_desc),
-                InMemoryDataOperation::none,
-                GemmMPerBlock,
-                GemmNPerBlock,
-                GemmKPerBlock,
-                GemmMPerThreadSubC,
-                GemmNPerThreadSubC,
-                GemmMLevel0Cluster,
-                GemmNLevel0Cluster,
-                GemmMLevel1Cluster,
-                GemmNLevel1Cluster,
-                GemmKPerThreadLoop,
-                GemmThreadGemmDataPerReadM,
-                GemmThreadGemmDataPerReadN,
-                GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
-                GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
-                Sequence<0, 1>,
-                Sequence<0, 1>,
-                1,
-                GemmABlockCopySrcDataPerRead_GemmM,
-                GemmABlockCopyDstDataPerWrite_GemmM,
-                GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
-                GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
-                Sequence<0, 1>,
-                Sequence<0, 1>,
-                1,
-                GemmBBlockCopySrcDataPerRead_GemmN,
-                GemmBBlockCopyDstDataPerWrite_GemmN,
-                Sequence<0, 1, 2, 3>,
-                3,
-                GemmCThreadCopyDstDataPerWrite_GemmN1>{};
-
-            return gridwise_gemm;
-        };
         // GEMMs
-        index_t shared_mem_size = 0;
-
-        static_for<0, Ytilda, 1>{}([&](auto ytilda) {
-            static_for<0, Xtilda, 1>{}([&](auto xtilda) {
-                auto gemm = f_get_gemm(ytilda, xtilda);
-                shared_mem_size = math::max(shared_mem_size, gemm.GetSharedMemorySize());
-            });
-        });
-
-        __shared__ Float p_shared_float[shared_mem_size / sizeof(Float)];
-
-        // GEMMs
-        static_for<0, Ytilda, 1>{}([&](auto ytilda) {
-            static_for<0, Xtilda, 1>{}([&](auto xtilda) {
-                auto gemm = f_get_gemm(ytilda, xtilda);
-
-                gemm.Run(p_wei_global, p_in_global, p_out_global, p_shared_float);
+        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
+
+        __shared__ Float p_shared_block[shared_block_size];
+
+#if 1 // debug
+        static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
+            static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
+#else
+        static_for<0, 1, 1>{}([&](auto ytilda_) {
+            static_for<0, 1, 1>{}([&](auto xtilda_) {
+#endif
+                constexpr index_t ytilda = decltype(ytilda_){};
+                constexpr index_t xtilda = decltype(xtilda_){};
+
+                constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
+                constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
+
+                // A matrix
+                constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
+                    transform_tensor_descriptor(
+                        wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
+                        make_tuple(PassThrough<K>{},
+                                   PassThrough<C>{},
+                                   Trim<Sequence<Ydot, Xdot>,
+                                        Sequence<0, 0>,
+                                        Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{},
+                                   Trim<Sequence<Ytilda, Xtilda>,
+                                        Sequence<ytilda, xtilda>,
+                                        Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
+                        make_tuple(
+                            Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
+                        make_tuple(
+                            Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
+
+                constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
+                    wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
+                    make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
+                               Merge<Sequence<C, 1, 1>>{}),
+                    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+                // B matrix
+                constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
+                    transform_tensor_descriptor(
+                        out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
+                        make_tuple(PassThrough<N>{},
+                                   PassThrough<K>{},
+                                   PassThrough<HtildaTrim>{},
+                                   PassThrough<WtildaTrim>{},
+                                   Trim<Sequence<Ydot, Xdot>,
+                                        Sequence<0, 0>,
+                                        Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<3>{},
+                                   Sequence<5>{},
+                                   Sequence<2, 4>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<3>{},
+                                   Sequence<5>{},
+                                   Sequence<2, 4>{}));
+
+                constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
+                    out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
+                    make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
+                               Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
+                    make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+                // C matrix
+                constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
+                    transform_tensor_descriptor(
+                        in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
+                        make_tuple(PassThrough<N>{},
+                                   PassThrough<C>{},
+                                   PassThrough<HtildaTrim>{},
+                                   PassThrough<WtildaTrim>{},
+                                   Trim<Sequence<Ytilda, Xtilda>,
+                                        Sequence<ytilda, xtilda>,
+                                        Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<3>{},
+                                   Sequence<5>{},
+                                   Sequence<2, 4>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<3>{},
+                                   Sequence<5>{},
+                                   Sequence<2, 4>{}));
+
+                constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
+                    in_n_c_1_htildatrim_1_wtildatrim_global_desc,
+                    make_tuple(Merge<Sequence<C, 1, 1>>{},
+                               Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
+                    make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+                constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
+                    GridSize,
+                    BlockSize,
+                    Float,
+                    AccFloat,
+                    decltype(wei_gemmk_gemmm_global_desc),
+                    decltype(out_gemmk_gemmn_global_desc),
+                    decltype(in_gemmm_gemmn_global_desc),
+                    InMemoryDataOperation::none,
+                    GemmMPerBlock,
+                    GemmNPerBlock,
+                    GemmKPerBlock,
+                    GemmMPerThreadSubC,
+                    GemmNPerThreadSubC,
+                    GemmMLevel0Cluster,
+                    GemmNLevel0Cluster,
+                    GemmMLevel1Cluster,
+                    GemmNLevel1Cluster,
+                    GemmKPerThreadLoop,
+                    GemmThreadGemmDataPerReadM,
+                    GemmThreadGemmDataPerReadN,
+                    GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+                    GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+                    Sequence<0, 1>,
+                    Sequence<0, 1>,
+                    1,
+                    GemmABlockCopySrcDataPerRead_GemmM,
+                    GemmABlockCopyDstDataPerWrite_GemmM,
+                    GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+                    GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+                    Sequence<0, 1>,
+                    Sequence<0, 1>,
+                    1,
+                    GemmBBlockCopySrcDataPerRead_GemmN,
+                    GemmBBlockCopyDstDataPerWrite_GemmN,
+                    Sequence<0, 1, 2, 3>,
+                    3,
+                    GemmCThreadCopyDstDataPerWrite_GemmN1>{};
+
+                gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block);
             });
         });
     }
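
Reviewer note: Run now instantiates one GEMM per (ytilda, xtilda) pair and trims the Ydot/Xdot axes down to the filter taps that actually contribute. A worked instance of the YdotNonZero edge case, under assumed sizes (Y = 3 filter taps split into Ydot = 2 taps per slice, so the last slice is partial):

    constexpr int Y = 3, Ydot = 2; // assumed sizes, not the kernel's parameters

    constexpr int ydot_non_zero(int ytilda)
    {
        return (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
    }

    static_assert(ydot_non_zero(0) == 2, "first slice uses a full Ydot run");
    static_assert(ydot_non_zero(1) == 1, "last slice keeps only Y % Ydot taps");
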
...
@@ -50,7 +50,7 @@ template <index_t GridSize,
           index_t CThreadCopyDstDataPerWrite>
 struct GridwiseGemmTransposedANormalBNormalC_v1
 {
-    __host__ __device__ static constexpr index_t GetSharedMemorySize()
+    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                     BBlockCopyDstDataPerWrite_N,
@@ -80,7 +80,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
     __device__ void Run(const Float* __restrict__ p_a_global,
                         const Float* __restrict__ p_b_global,
                         Float* __restrict__ p_c_global,
-                        void* __restrict__ p_shared) const
+                        Float* __restrict__ p_shared_block) const
     {
         constexpr auto True = integral_constant<bool, true>{};
@@ -92,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         constexpr auto M = a_k_m_global_desc.GetLengths()[1];
         constexpr auto N = b_k_n_global_desc.GetLengths()[1];

+        // don't do anything if K == 0
+        if(K == 0)
+        {
+            return;
+        }
+
         // lds max alignment
         constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                     BBlockCopyDstDataPerWrite_N,
@@ -212,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         constexpr index_t b_block_space =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

-        Float* p_a_block_double = reinterpret_cast<Float*>(p_shared);
-        Float* p_b_block_double = p_a_block_double + 2 * a_block_space;
+        Float* p_a_block_double = p_shared_block;
+        Float* p_b_block_double = p_shared_block + 2 * a_block_space;

         // register allocation for output
         AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
@@ -362,11 +368,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
                         const Float* __restrict__ p_b_global,
                         Float* __restrict__ p_c_global) const
     {
-        constexpr index_t shared_mem_size = GetSharedMemorySize();
+        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

-        __shared__ Float p_shared_float[shared_mem_size / sizeof(Float)];
+        __shared__ Float p_shared_block[shared_block_size];

-        Run(p_a_global, p_b_global, p_c_global, p_shared_float);
+        Run(p_a_global, p_b_global, p_c_global, p_shared_block);
     }
 };
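
Reviewer note: with p_shared_block now typed as Float*, the reinterpret_cast goes away, and the pointer offsets above imply an [A0][A1][B0][B1] double-buffer layout in LDS: two A tiles followed by two B tiles, so the copy stage can fill the "next" buffer while the current one is read. A hedged sketch of that partitioning with assumed, already aligned tile sizes:

    constexpr int a_block_space = 512, b_block_space = 512; // assumed element counts

    constexpr int a_buf0 = 0;                                // p_a_block_double
    constexpr int a_buf1 = a_block_space;                    // second A buffer
    constexpr int b_buf0 = 2 * a_block_space;                // p_b_block_double
    constexpr int b_buf1 = 2 * a_block_space + b_block_space;

    static_assert(b_buf1 + b_block_space == 2 * (a_block_space + b_block_space),
                  "total matches the 2 * (a + b) elements sized in GetSharedMemoryNumberOfByte");
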
...
@@ -84,36 +84,6 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 1
-    // BlockSize = 256, each thread hold 64 data
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t GemmKPerThreadLoop = 1;
-
-    constexpr index_t GemmThreadGemmDataPerReadM = 4;
-    constexpr index_t GemmThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<2, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<2, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif
...
@@ -25,10 +25,10 @@ int main(int argc, char* argv[])
 #if 1
     // 3x3 filter, 2x2 stride, 35x35 input
     constexpr index_t N  = 128;
-    constexpr index_t C  = 128;
+    constexpr index_t C  = 1024;
     constexpr index_t HI = 35;
     constexpr index_t WI = 35;
-    constexpr index_t K  = 128;
+    constexpr index_t K  = 1024;
     constexpr index_t Y  = 3;
     constexpr index_t X  = 3;
@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<2, 2>;
     using RightPads = Sequence<2, 2>;
-#elif 0
+#elif 1
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;
...