"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "6bdd3830208951d6020e3fea0c8fbf21871d3344"
Unverified commit 1a66e35b, authored by Chao Liu, committed by GitHub

MIOpen integration (#13)

* update for MIOpen integration: cosmetic refactor
parent 3406a114
@@ -54,21 +54,23 @@
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
+#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0
+#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0
 
 namespace ck {
 
 enum AddressSpace
 {
-    generic,
-    global,
-    lds,
-    vgpr
+    Generic,
+    Global,
+    Lds,
+    Vgpr
 };
 
 enum InMemoryDataOperation
 {
-    none,
-    atomic_add
+    Set,
+    AtomicAdd
 };
 
 #if CK_UNSIGNED_INDEX_TYPE
...
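The renamed CamelCase enumerators are used throughout the copy helpers as compile-time tags. A minimal stand-alone sketch of that dispatch style (not from this commit; `my_transfer` is a hypothetical stand-in for ck's `transfer_data`, and `if constexpr` requires C++17 whereas ck itself uses its own `static_if`):

    // Stand-alone sketch: CamelCase enumerators selecting a code path at
    // compile time. "my_transfer" is illustrative, not part of this commit.
    #include <cstdio>

    enum AddressSpace { Generic, Global, Lds, Vgpr };
    enum InMemoryDataOperation { Set, AtomicAdd };

    template <AddressSpace Src, AddressSpace Dst, InMemoryDataOperation Op>
    void my_transfer()
    {
        if constexpr(Op == Set)
            std::printf("set path: plain vector store\n");
        else
            std::printf("atomic path: read-modify-write accumulate\n");
    }

    int main()
    {
        my_transfer<Global, Vgpr, Set>();       // plain store
        my_transfer<Vgpr, Global, AtomicAdd>(); // atomic accumulate
        return 0;
    }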
@@ -10,13 +10,14 @@ template <typename T,
           index_t DataPerAccess,
           AddressSpace SrcAddressSpace,
           AddressSpace DstAddressSpace>
-__device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
+__device__ void set_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
     using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
 
 #if CK_USE_AMD_BUFFER_ADDRESSING
     // TODO: use static_if::ElseIf, instead of nested static_if
-    static_if<SrcAddressSpace == AddressSpace::global && DstAddressSpace == vgpr>{}([&](auto) {
+    static_if<SrcAddressSpace == AddressSpace::Global &&
+              DstAddressSpace == AddressSpace::Vgpr>{}([&](auto) {
         // buffer_load requires:
         // 1) p_src must be in global memory space, d_dst must be vgpr
         // 2) p_src to be a block-invariant pointer.
@@ -24,7 +25,8 @@ __device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
         *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
             amd_intrinsic_buffer_load<T, DataPerAccess>(p_src, src_offset, 0);
     }).Else([&](auto) {
-        static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == global>{}([&](auto) {
+        static_if<SrcAddressSpace == AddressSpace::Vgpr &&
+                  DstAddressSpace == AddressSpace::Global>{}([&](auto) {
             // buffer_store requires:
             // 1) p_src must be in vgpr space, d_dst must be global memory
             // 2) p_dst to be a block-invariant pointer.
@@ -50,19 +52,18 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
     using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
 
-    static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
-        [&](auto) {
+    static_if<SrcAddressSpace == AddressSpace::Vgpr &&
+              DstAddressSpace == AddressSpace::Global>{}([&](auto) {
 #if CK_USE_AMD_BUFFER_ATOMIC_ADD
         amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
             *reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
 #else
         atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
                   *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
 #endif
-    })
-        .Else([&](auto fwd) {
-            static_assert(fwd(false), "atomic_add doesn't support this memory space");
-        });
+    }).Else([&](auto fwd) {
+        static_assert(fwd(false), "atomic_add doesn't support this memory space");
+    });
 }
 
 template <typename T,
@@ -72,17 +73,17 @@ template <typename T,
           InMemoryDataOperation DstInMemOp>
 __device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
-    static_assert(DstInMemOp == InMemoryDataOperation::none ||
-                      DstInMemOp == InMemoryDataOperation::atomic_add,
+    static_assert(DstInMemOp == InMemoryDataOperation::Set ||
+                      DstInMemOp == InMemoryDataOperation::AtomicAdd,
                   "wrong! InMemoryDataOperation not supported!");
 
     // TODO: use static_if::ElseIf
-    static_if<DstInMemOp == InMemoryDataOperation::none>{}([&](auto) {
-        copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
+    static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
+        set_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
            p_src, src_offset, p_dst, dst_offset);
     });
 
-    static_if<DstInMemOp == InMemoryDataOperation::atomic_add>{}([&](auto) {
+    static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
         atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
             p_src, src_offset, p_dst, dst_offset);
     });
...
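The `static_assert(fwd(false), ...)` in the `Else` branch relies on the condition being dependent on the generic lambda's parameter: the assertion can only fire if that branch is actually instantiated. A self-contained C++14 sketch of the idiom (this `static_if` is a simplified stand-in for ck's utility, not the real implementation):

    // Simplified stand-in for ck::static_if, enough to show the fwd(false)
    // trick: the assert's condition depends on the lambda parameter, so the
    // compiler cannot evaluate it unless the Else body is instantiated.
    struct forwarder
    {
        template <typename T>
        constexpr T operator()(T x) const
        {
            return x;
        }
    };

    template <bool Predicate>
    struct static_if;

    template <>
    struct static_if<true>
    {
        template <typename F>
        const static_if& operator()(F f) const
        {
            f(forwarder{}); // run the Then branch
            return *this;
        }

        template <typename F>
        void Else(F) const
        {
            // predicate is true: the Else lambda is never called,
            // so its body (and the static_assert) is never instantiated
        }
    };

    template <>
    struct static_if<false>
    {
        template <typename F>
        const static_if& operator()(F) const
        {
            return *this; // skip the Then branch
        }

        template <typename F>
        void Else(F f) const
        {
            f(forwarder{}); // run the Else branch
        }
    };

    template <int Op> // 0 = set, 1 = atomic_add (illustrative encoding)
    void transfer_like()
    {
        static_if<Op == 0>{}([&](auto) { /* set path */ }).Else([&](auto fwd) {
            static_assert(fwd(false), "unsupported op"); // fires only when taken
        });
    }

    int main()
    {
        transfer_like<0>(); // compiles: set path chosen, Else never instantiated
        // transfer_like<1>(); // would trip the static_assert at compile time
        return 0;
    }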
@@ -23,14 +23,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
     using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
 
-    static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
-        [&](auto) {
+    static_if<SrcAddressSpace == AddressSpace::Vgpr &&
+              DstAddressSpace == AddressSpace::Global>{}([&](auto) {
         atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
                   *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
-    })
-        .Else([&](auto fwd) {
-            static_assert(fwd(false), "atomic_add doesn't support this memory space");
-        });
+    }).Else([&](auto fwd) {
+        static_assert(fwd(false), "atomic_add doesn't support this memory space");
+    });
 }
 
 template <typename T,
@@ -40,17 +39,17 @@ template <typename T,
           InMemoryDataOperation DstInMemOp>
 __device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
-    static_assert(DstInMemOp == InMemoryDataOperation::none ||
-                      DstInMemOp == InMemoryDataOperation::atomic_add,
+    static_assert(DstInMemOp == InMemoryDataOperation::Set ||
+                      DstInMemOp == InMemoryDataOperation::AtomicAdd,
                   "wrong! InMemoryDataOperation not supported!");
 
     // TODO: use static_if::ElseIf
-    static_if<DstInMemOp == InMemoryDataOperation::none>{}([&](auto) {
+    static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
         copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
             p_src, src_offset, p_dst, dst_offset);
     });
 
-    static_if<DstInMemOp == InMemoryDataOperation::atomic_add>{}([&](auto) {
+    static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
         atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
             p_src, src_offset, p_dst, dst_offset);
     });
...
@@ -107,27 +107,22 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
 template <typename T>
 __host__ __device__ constexpr T gcd(T x, T y)
 {
-    if(x == 0)
+    if(x == y || x == 0)
     {
         return y;
     }
-    else if(y == 0)
+
+    if(y == 0)
     {
         return x;
     }
-    else if(x == y)
-    {
-        return x;
-    }
-    else if(x > y)
+
+    if(x > y)
     {
         return gcd(x - y, y);
     }
-    else
-    {
-        return gcd(x, y - x);
-    }
+
+    return gcd(x, y - x);
 }
 
 template <index_t X, index_t Y>
@@ -150,10 +145,10 @@ __host__ __device__ constexpr T lcm(T x, T y)
     return (x * y) / gcd(x, y);
 }
 
-template <typename X, typename Y, typename... Zs>
-__host__ __device__ constexpr auto lcm(X x, Y y, Zs... zs)
+template <typename X, typename... Ys>
+__host__ __device__ constexpr auto lcm(X x, Ys... ys)
 {
-    return lcm(x, lcm(y, zs...));
+    return lcm(x, lcm(ys...));
 }
 
 template <class T>
...
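The refactored `gcd` is the subtraction form of Euclid's algorithm rewritten with early returns, and the variadic `lcm` now folds over the tail. A host-only sanity check, with the definitions above specialized to `int` (the numbers are mine, not from the source):

    // Host-only check of the refactored helpers.
    // gcd(24, 36) -> gcd(24, 12) -> gcd(12, 12) -> 12
    constexpr int gcd(int x, int y)
    {
        if(x == y || x == 0)
            return y;
        if(y == 0)
            return x;
        if(x > y)
            return gcd(x - y, y);
        return gcd(x, y - x);
    }

    constexpr int lcm(int x, int y) { return (x * y) / gcd(x, y); }

    // variadic form folds right-to-left, exactly like the template above:
    // lcm(2, 3, 4) = lcm(2, lcm(3, 4)) = lcm(2, 12) = 12
    template <typename X, typename... Ys>
    constexpr auto lcm(X x, Ys... ys)
    {
        return lcm(x, lcm(ys...));
    }

    static_assert(gcd(24, 36) == 12, "");
    static_assert(lcm(4, 6) == 12, "");
    static_assert(lcm(2, 3, 4) == 12, "");

For exactly two arguments the non-template overload is preferred, so the variadic version is never instantiated with a single trailing argument; the same property holds for the `(X, Ys...)` template in the diff.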
@@ -49,20 +49,20 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(...)
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
 
-#if 1
+#if 0
     // BlockSize = 256, each thread hold 64 data
     constexpr index_t BlockSize = 256;
 
     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
 
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
 
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
 
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;
@@ -79,6 +79,36 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
 
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 1
+    // BlockSize = 256, each thread hold 64 data
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 16;
+
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
+
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+
+    constexpr index_t GemmThreadGemmDataPerReadM = 4;
+    constexpr index_t GemmThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #endif
 
     constexpr index_t GemmM = C * Y * X;
@@ -104,13 +134,13 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(...)
         GemmMPerBlock,
         GemmNPerBlock,
         GemmKPerBlock,
-        GemmMPerThreadSubC,
-        GemmNPerThreadSubC,
+        GemmMPerThread,
+        GemmNPerThread,
+        GemmKPerThread,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
-        GemmKPerThreadLoop,
        GemmThreadGemmDataPerReadM,
         GemmThreadGemmDataPerReadN,
         GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
...
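The "each thread hold 64 data" comment is consistent with the renamed parameters. The relations below are inferred from the parameter names (they are not spelled out in this file), and can be checked with plain host-side arithmetic:

    // Inferred sanity arithmetic for the tuning blocks above (host-only).
    constexpr int BlockSize     = 256;
    constexpr int GemmMPerBlock = 128;
    constexpr int GemmNPerBlock = 128;

    constexpr int GemmMLevel0Cluster = 4;
    constexpr int GemmNLevel0Cluster = 4;
    constexpr int GemmMLevel1Cluster = 4;
    constexpr int GemmNLevel1Cluster = 4;

    // 4 * 4 * 4 * 4 = 256: the two-level thread clusters account for the
    // whole workgroup
    static_assert(GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster *
                          GemmNLevel1Cluster ==
                      BlockSize,
                  "");

    // 128 * 128 / 256 = 64 C-tile elements per thread, matching the comment
    static_assert((GemmMPerBlock * GemmNPerBlock) / BlockSize == 64, "");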
@@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
 
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
 
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
 
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;
@@ -96,13 +96,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
 
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
 
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
 
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;
@@ -127,13 +127,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
 
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
 
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
 
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;
@@ -152,33 +152,33 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #endif
 
-    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
-    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
+    constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
 
-    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
+    constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
+    constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
 
-    constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
-    constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
+    constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
+    constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
 
-    constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
-    constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
+    constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
+    constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
 
-    constexpr index_t HtildaLeft = math::integer_divide_floor(
-        math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
-    constexpr index_t WtildaLeft = math::integer_divide_floor(
-        math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
+    constexpr index_t HTildaLeft = math::integer_divide_floor(
+        math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
+    constexpr index_t WTildaLeft = math::integer_divide_floor(
+        math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
 
-    constexpr index_t HtildaRight = math::min(
-        Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
-    constexpr index_t WtildaRight = math::min(
-        Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
+    constexpr index_t HTildaRight = math::min(
+        HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
+    constexpr index_t WTildaRight = math::min(
+        WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
 
-    constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
-    constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
+    constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
+    constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
 
-    constexpr index_t GemmM = C * Ytilda * Xtilda;
-    constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
+    constexpr index_t GemmM = C * YTilda * XTilda;
+    constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
 
     constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                  math::integer_divide_ceil(GemmN, GemmNPerBlock);
@@ -200,13 +200,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(...)
         GemmMPerBlock,
         GemmNPerBlock,
         GemmKPerBlock,
-        GemmMPerThreadSubC,
-        GemmNPerThreadSubC,
+        GemmMPerThread,
+        GemmNPerThread,
+        GemmKPerThread,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
-        GemmKPerThreadLoop,
         GemmThreadGemmDataPerReadM,
         GemmThreadGemmDataPerReadN,
         GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
...
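For intuition about the renamed tilde quantities: with a 3x3 filter, 2x2 stride, and 1x1 dilation (the 35x35-input case enabled in the driver below, where Ho = 17), the values work out as follows. This is a host-side recomputation with `integer_divide_ceil(a, b)` written out as `(a + b - 1) / b`; the concrete numbers are mine, not from the source:

    // Worked numbers for the tilde transform (3x3 filter, 2x2 stride,
    // 1x1 dilation, Ho = 17), recomputed with plain integer arithmetic.
    constexpr int ConvStrideH = 2, ConvDilationH = 1, Y = 3, Ho = 17;

    constexpr int GcdStrideDilationH = 1;                    // gcd(2, 1)
    constexpr int YTilda = ConvStrideH / GcdStrideDilationH; // 2
    constexpr int YDot   = (Y + YTilda - 1) / YTilda;        // ceil(3 / 2) = 2
    constexpr int HTilda =
        Ho + (ConvDilationH * (Y - 1) + ConvStrideH - 1) / ConvStrideH; // 17 + 1 = 18

    static_assert(YTilda == 2 && YDot == 2 && HTilda == 18, "");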
@@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
 
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
 
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
 
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;
@@ -91,33 +91,33 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif
 
-    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
-    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
+    constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
 
-    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
+    constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
+    constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
 
-    constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
-    constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
+    constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
+    constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
 
-    constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
-    constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
+    constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
+    constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
 
-    constexpr index_t HtildaLeft = math::integer_divide_floor(
-        math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
-    constexpr index_t WtildaLeft = math::integer_divide_floor(
-        math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
+    constexpr index_t HTildaLeft = math::integer_divide_floor(
+        math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
+    constexpr index_t WTildaLeft = math::integer_divide_floor(
+        math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
 
-    constexpr index_t HtildaRight = math::min(
-        Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
-    constexpr index_t WtildaRight = math::min(
-        Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
+    constexpr index_t HTildaRight = math::min(
+        HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
+    constexpr index_t WTildaRight = math::min(
+        WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
 
-    constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
-    constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
+    constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
+    constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
 
     constexpr index_t GemmM = C;
-    constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
+    constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
 
     constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                  math::integer_divide_ceil(GemmN, GemmNPerBlock);
@@ -139,13 +139,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(...)
         GemmMPerBlock,
         GemmNPerBlock,
         GemmKPerBlock,
-        GemmMPerThreadSubC,
-        GemmNPerThreadSubC,
+        GemmMPerThread,
+        GemmNPerThread,
+        GemmKPerThread,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
-        GemmKPerThreadLoop,
         GemmThreadGemmDataPerReadM,
         GemmThreadGemmDataPerReadN,
         GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
...
@@ -69,13 +69,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
 
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
 
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
 
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;
@@ -99,13 +99,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 16;
 
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
 
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
 
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;
@@ -124,33 +124,33 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(...)
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif
 
-    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
-    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
+    constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
 
-    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
+    constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
+    constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
 
-    constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
-    constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
+    constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
+    constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
 
-    constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
-    constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
+    constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
+    constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
 
-    constexpr index_t HtildaLeft = math::integer_divide_floor(
-        math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
-    constexpr index_t WtildaLeft = math::integer_divide_floor(
-        math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
+    constexpr index_t HTildaLeft = math::integer_divide_floor(
+        math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
+    constexpr index_t WTildaLeft = math::integer_divide_floor(
+        math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
 
-    constexpr index_t HtildaRight = math::min(
-        Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
-    constexpr index_t WtildaRight = math::min(
-        Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
+    constexpr index_t HTildaRight = math::min(
+        HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
+    constexpr index_t WTildaRight = math::min(
+        WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
 
-    constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
-    constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
+    constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
+    constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
 
     constexpr index_t GemmM = C;
-    constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
+    constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
 
     constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                  math::integer_divide_ceil(GemmN, GemmNPerBlock);
@@ -159,7 +159,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(...)
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        using GridwiseConv = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
+        using GridwiseConvBwdData = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
             GridSize,
             BlockSize,
             T,
@@ -174,13 +174,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(...)
             GemmMPerBlock,
             GemmNPerBlock,
             GemmKPerBlock,
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
+            GemmMPerThread,
+            GemmNPerThread,
+            GemmKPerThread,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
             GemmThreadGemmDataPerReadM,
             GemmThreadGemmDataPerReadN,
             GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -196,21 +196,29 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(...)
         KernelTimer timer;
         timer.Start();
 
-        static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
+        static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
             constexpr index_t gemm_id = decltype(gemm_id_){};
 
-            launch_kernel(run_gridwise_convolution_backward_data_v4r1<GridwiseConv,
-                                                                      gemm_id,
-                                                                      T* const __restrict__,
-                                                                      const T* const __restrict__,
-                                                                      const T* const __restrict__>,
-                          dim3(GridSize),
-                          dim3(BlockSize),
-                          0,
-                          0,
-                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+            constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
+
+            constexpr index_t gemm_k = gemm_sizes.At(2);
+
+            constexpr bool is_gemm_not_empty = gemm_k > 0;
+
+            // only compile and run if GEMM is not empty
+            static_if<is_gemm_not_empty>{}([&](auto fwd) {
+                launch_kernel(
+                    run_gridwise_convolution_backward_data_v4r1<GridwiseConvBwdData,
+                                                                fwd(gemm_id),
+                                                                T* const __restrict__,
+                                                                const T* const __restrict__,
+                                                                const T* const __restrict__>,
+                    dim3(GridSize),
+                    dim3(BlockSize),
+                    0,
+                    0,
+                    static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                    static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+                    static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+            });
         });
 
         timer.End();
...
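The reworked launch loop asks `GridwiseConvBwdData::GetGemmSize(gemm_id)` for each sub-problem and wraps the launch in `static_if`, so kernels for empty GEMMs (`gemm_k == 0`) are neither instantiated nor launched. A minimal stand-alone analogue of that guard, using `if constexpr` (C++17) in place of ck's `static_if`; the names here are illustrative, not the real launcher:

    // Minimal analogue (not the real launcher): skip degenerate sub-GEMMs at
    // compile time so no kernel is instantiated for them.
    #include <cstdio>

    template <int GemmK>
    void launch_sub_gemm()
    {
        if constexpr(GemmK > 0) // plays the role of static_if<is_gemm_not_empty>
        {
            std::printf("launch sub-GEMM with K = %d\n", GemmK);
        }
        else
        {
            std::printf("sub-GEMM is empty, nothing to do\n");
        }
    }

    int main()
    {
        launch_sub_gemm<8>(); // compiled and launched
        launch_sub_gemm<0>(); // guarded out: the launch branch is never instantiated
        return 0;
    }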
@@ -23,17 +23,16 @@ int main(int argc, char* argv[])
 {
     using namespace launcher;
 
-#if 0
-    // 3x3 filter, 2x2 stride, 35x35 input
-    constexpr index_t N  = 128;
-    constexpr index_t C  = 1024;
-    constexpr index_t HI = 35;
-    constexpr index_t WI = 35;
-    constexpr index_t K  = 1024;
-    constexpr index_t Y  = 3;
-    constexpr index_t X  = 3;
+#if 1
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 256;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K  = 256;
+    constexpr index_t Y  = 1;
+    constexpr index_t X  = 1;
 
-    using ConvStrides   = Sequence<2, 2>;
+    using ConvStrides   = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;
 
     using LeftPads  = Sequence<0, 0>;
@@ -158,7 +157,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<2, 2>;
     using RightPads = Sequence<2, 2>;
 
-#elif 1
+#elif 0
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -188,7 +187,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
 
-#elif 0
+#elif 1
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     constexpr index_t N = 128;
     constexpr index_t C = 1024;
...
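The configurations toggled on here are consistent with their comments. Using the standard convolution output-size formula (not spelled out in the driver), and assuming zero right padding to match the `Sequence<0, 0>` left pads shown, the numbers check out:

    // Output-size check for the enabled cases (standard formula, host-only).
    constexpr int out_size(int in, int pad_l, int pad_r, int dilation, int filter, int stride)
    {
        return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
    }

    // 1x1 filter, 1x1 stride, no padding: 56x56 input stays 56x56
    static_assert(out_size(56, 0, 0, 1, 1, 1) == 56, "");

    // 3x3 filter, 2x2 stride, no padding: 35x35 input gives 17x17 output
    static_assert(out_size(35, 0, 0, 1, 3, 2) == 17, "");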