Unverified Commit c5da0377 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Added bwd data v3r1 v4r1, tweaking v1 (#10)

* Added bwd data v3r1: breaking down compute into a series of load balanced GEMM, and launch in a single kernel
* Added bwd data v4r1: like v3r1, but launch GEMMs in multiple kernels
* Tweaked v1r1  and v1r2 (atomic) on AMD GPU
parent e2b4c5b4
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T, template <typename T,
typename InDesc, typename InDesc,
typename WeiDesc, typename WeiDesc,
...@@ -121,20 +125,18 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i ...@@ -121,20 +125,18 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv), float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__, T* const __restrict__,
const T* const __restrict__, const T* const __restrict__,
const T* const __restrict__>, const T* const __restrict__>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
gridwise_conv, gridwise_conv,
const_cast<T* const __restrict__>( static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())), static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
const_cast<const T* const __restrict__>( static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
printf("Elapsed time : %f ms, %f TFlop/s\n", printf("Elapsed time : %f ms, %f TFlop/s\n",
time, time,
...@@ -145,3 +147,5 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i ...@@ -145,3 +147,5 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf.FromDevice(in_nchw.mData.data()); in_nchw_device_buf.FromDevice(in_nchw.mData.data());
} }
} // namespace launcher
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"
namespace launcher {
using namespace ck;
template <typename T, template <typename T,
typename InDesc, typename InDesc,
typename WeiDesc, typename WeiDesc,
...@@ -129,20 +133,18 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i ...@@ -129,20 +133,18 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv), float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__, T* const __restrict__,
const T* const __restrict__, const T* const __restrict__,
const T* const __restrict__>, const T* const __restrict__>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
gridwise_conv, gridwise_conv,
const_cast<T* const __restrict__>( static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())), static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
const_cast<const T* const __restrict__>( static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
printf("Elapsed time : %f ms, %f TFlop/s\n", printf("Elapsed time : %f ms, %f TFlop/s\n",
time, time,
...@@ -153,3 +155,5 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i ...@@ -153,3 +155,5 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf.FromDevice(in_nchw.mData.data()); in_nchw_device_buf.FromDevice(in_nchw.mData.data());
} }
} // namespace launcher
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T, template <typename T,
typename InDesc, typename InDesc,
typename WeiDesc, typename WeiDesc,
...@@ -29,10 +33,14 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -29,10 +33,14 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t N = out_nkhw_desc.GetLengths()[0]; constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1]; constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
...@@ -81,6 +89,67 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -81,6 +89,67 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, each thread hold 64 data
// for 1x1 weight, 8x8 input
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#endif #endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
...@@ -92,14 +161,24 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -92,14 +161,24 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda); constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda); constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t Htilda = Ho + right_pad_ho; constexpr index_t HtildaRight = math::min(
constexpr index_t Wtilda = Wo + right_pad_wo; Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t GemmM = C * Ytilda * Xtilda; constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda; constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock); math::integer_divide_ceil(GemmN, GemmNPerBlock);
...@@ -142,20 +221,18 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -142,20 +221,18 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv), float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__, T* const __restrict__,
const T* const __restrict__, const T* const __restrict__,
const T* const __restrict__>, const T* const __restrict__>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
gridwise_conv, gridwise_conv,
const_cast<T* const __restrict__>( static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())), static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
const_cast<const T* const __restrict__>( static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
printf("Elapsed time : %f ms, %f TFlop/s\n", printf("Elapsed time : %f ms, %f TFlop/s\n",
time, time,
...@@ -166,3 +243,5 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -166,3 +243,5 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf.FromDevice(in_nchw.mData.data()); in_nchw_device_buf.FromDevice(in_nchw.mData.data());
} }
} // namespace launcher
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
std::size_t nrepeat)
{
using namespace ck;
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
std::size_t nrepeat)
{
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
KernelTimer timer;
timer.Start();
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){};
constexpr auto gridwise_conv =
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ytilda,
xtilda,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
});
});
timer.End();
float time = timer.GetElapsedTime();
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
...@@ -82,7 +82,7 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc, ...@@ -82,7 +82,7 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc,
WoPerThread, WoPerThread,
InBlockCopyDataPerRead, InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>; WeiBlockCopyDataPerRead>;
float time = launch_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>, float time = launch_and_time_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -458,7 +458,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -458,7 +458,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -161,7 +161,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc, ...@@ -161,7 +161,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -354,7 +354,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, ...@@ -354,7 +354,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
WeiBlockCopyDataPerRead_K, WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_W>{}; OutThreadCopyDataPerWrite_W>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -306,7 +306,8 @@ void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc, ...@@ -306,7 +306,8 @@ void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc,
WeiBlockCopyDataPerRead, WeiBlockCopyDataPerRead,
OutThreadCopyDataPerWrite>{}; OutThreadCopyDataPerWrite>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -135,7 +135,8 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc, ...@@ -135,7 +135,8 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
WeiBlockCopyClusterLengths_C_K, WeiBlockCopyClusterLengths_C_K,
WeiBlockCopyDataPerAccess_K>{}; WeiBlockCopyDataPerAccess_K>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 1
// BlockSize = 256, EperBlock = 8, each thread hold 64 data // BlockSize = 256, EperBlock = 8, each thread hold 64 data
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 1
// BlockSize = 64, each thread hold 64 data // BlockSize = 64, each thread hold 64 data
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -258,13 +258,15 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -258,13 +258,15 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv), float time =
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
const T* const __restrict__, const T* const __restrict__,
const T* const __restrict__, const T* const __restrict__,
T* const __restrict__>, T* const __restrict__>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
gridwise_conv, gridwise_conv,
const_cast<const T* const __restrict__>( const_cast<const T* const __restrict__>(
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())), static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
......
...@@ -81,6 +81,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc, ...@@ -81,6 +81,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 0
...@@ -247,10 +284,12 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc, ...@@ -247,10 +284,12 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()), static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()), static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())); static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
......
...@@ -200,7 +200,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, ...@@ -200,7 +200,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{}; WeiBlockCopyDstDataPerWrite_K>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -158,7 +158,8 @@ void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc, ...@@ -158,7 +158,8 @@ void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{}; WeiBlockCopyDstDataPerWrite_K>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1 #if 0
// BlockSize = 256, GemmKPerBlock = 8 // BlockSize = 256, GemmKPerBlock = 8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -83,6 +83,37 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -83,6 +83,37 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, GemmKPerBlock = 16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0 #elif 0
// BlockSize = 256, GemmKPerBlock = 8 // BlockSize = 256, GemmKPerBlock = 8
...@@ -116,7 +147,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -116,7 +147,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0 #elif 1
// BlockSize = 256, GemmKPerBlock = 16 // BlockSize = 256, GemmKPerBlock = 16
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -225,10 +256,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -225,10 +256,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()), static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()), static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())); static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
......
...@@ -205,7 +205,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc, ...@@ -205,7 +205,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -178,7 +178,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc, ...@@ -178,7 +178,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel( float time = launch_and_time_kernel(
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei, gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
TOut, TOut,
accum_t, accum_t,
......
...@@ -16,22 +16,25 @@ ...@@ -16,22 +16,25 @@
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck; using namespace launcher;
#if 1 #if 0
constexpr index_t N = 8; // 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t C = 128; constexpr index_t N = 128;
constexpr index_t HI = 16; constexpr index_t C = 1024;
constexpr index_t WI = 16; constexpr index_t HI = 35;
constexpr index_t K = 8; constexpr index_t WI = 35;
constexpr index_t Y = 2; constexpr index_t K = 1024;
constexpr index_t X = 2; constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<4, 4>; using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
...@@ -41,7 +44,7 @@ int main(int argc, char* argv[]) ...@@ -41,7 +44,7 @@ int main(int argc, char* argv[])
constexpr index_t C = 256; constexpr index_t C = 256;
constexpr index_t HI = 34; constexpr index_t HI = 34;
constexpr index_t WI = 34; constexpr index_t WI = 34;
constexpr index_t K = 128; constexpr index_t K = 256;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -51,27 +54,27 @@ int main(int argc, char* argv[]) ...@@ -51,27 +54,27 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 8x8 image // 3x3, 28x28
constexpr index_t N = 64; constexpr index_t N = 128;
constexpr index_t C = 1536; constexpr index_t C = 1024;
constexpr index_t HI = 8; constexpr index_t HI = 28;
constexpr index_t WI = 8; constexpr index_t WI = 28;
constexpr index_t K = 256; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 3;
constexpr index_t X = 1; constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 0
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
constexpr index_t N = 128; constexpr index_t N = 256;
constexpr index_t C = 2048; constexpr index_t C = 1024;
constexpr index_t HI = 8; constexpr index_t HI = 8;
constexpr index_t WI = 8; constexpr index_t WI = 8;
constexpr index_t K = 384; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -83,25 +86,10 @@ int main(int argc, char* argv[]) ...@@ -83,25 +86,10 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 1x1 filter, 7x7 image // 1x1 filter, 7x7 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 832; constexpr index_t C = 1024;
constexpr index_t HI = 7; constexpr index_t HI = 7;
constexpr index_t WI = 7; constexpr index_t WI = 7;
constexpr index_t K = 384; constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
constexpr index_t N = 128;
constexpr index_t C = 1280;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 384;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -123,27 +111,12 @@ int main(int argc, char* argv[]) ...@@ -123,27 +111,12 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
constexpr index_t N = 64;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 128;
constexpr index_t HI = 28; constexpr index_t HI = 28;
constexpr index_t WI = 28; constexpr index_t WI = 28;
constexpr index_t K = 128; constexpr index_t K = 128;
...@@ -153,30 +126,15 @@ int main(int argc, char* argv[]) ...@@ -153,30 +126,15 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 17x17 input // 1x1 filter, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 768; constexpr index_t C = 1024;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -186,57 +144,57 @@ int main(int argc, char* argv[]) ...@@ -186,57 +144,57 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 14x14 image // 5x5 filter, 2x2 pad, 7x7 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 528; constexpr index_t C = 1024;
constexpr index_t HI = 14; constexpr index_t HI = 7;
constexpr index_t WI = 14; constexpr index_t WI = 7;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 5;
constexpr index_t X = 1; constexpr index_t X = 5;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<2, 2>;
#elif 0 #elif 1
// 1x1 filter, 14x14 image // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 528; constexpr index_t C = 1024;
constexpr index_t HI = 14; constexpr index_t HI = 17;
constexpr index_t WI = 14; constexpr index_t WI = 17;
constexpr index_t K = 256; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 3>;
#elif 0 #elif 0
// 1x1 filter, 7x7 image // 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 832; constexpr index_t C = 1024;
constexpr index_t HI = 7; constexpr index_t HI = 17;
constexpr index_t WI = 7; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 7;
constexpr index_t X = 1; constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<3, 0>;
#elif 0 #elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 288; constexpr index_t C = 1024;
constexpr index_t HI = 35; constexpr index_t HI = 35;
constexpr index_t WI = 35; constexpr index_t WI = 35;
constexpr index_t K = 384; constexpr index_t K = 128;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -245,51 +203,6 @@ int main(int argc, char* argv[]) ...@@ -245,51 +203,6 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0
// 5x5 filter, 2x2 pad, 7x7 input
constexpr index_t N = 128;
constexpr index_t C = 48;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 128;
constexpr index_t Y = 5;
constexpr index_t X = 5;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#endif #endif
constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{}); constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
...@@ -337,8 +250,12 @@ int main(int argc, char* argv[]) ...@@ -337,8 +250,12 @@ int main(int argc, char* argv[])
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0 #elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#else #elif 1
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
#elif 1
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif #endif
(in_nchw_desc, (in_nchw_desc,
in_nchw_device, in_nchw_device,
......
...@@ -30,32 +30,63 @@ int main(int argc, char* argv[]) ...@@ -30,32 +30,63 @@ int main(int argc, char* argv[])
using namespace ck; using namespace ck;
#if 0 #if 0
constexpr index_t N = 8; // 1x1
constexpr index_t C = 32; constexpr index_t N = 256;
constexpr index_t HI = 28; constexpr index_t C = 1024;
constexpr index_t WI = 28; constexpr index_t HI = 8;
constexpr index_t K = 32; constexpr index_t WI = 8;
constexpr index_t Y = 5; constexpr index_t K = 1024;
constexpr index_t X = 5; constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 1x7
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
constexpr index_t HI = 34; constexpr index_t HI = 34;
constexpr index_t WI = 34; constexpr index_t WI = 34;
constexpr index_t K = 128; constexpr index_t K = 256;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
...@@ -282,35 +313,35 @@ int main(int argc, char* argv[]) ...@@ -282,35 +313,35 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>;
#elif 0 #elif 0
// 7x1 filter, 3x0 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 128;
constexpr index_t Y = 7; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<0, 3>;
#elif 1 #elif 1
// 1x7 filter, 0x3 pad, 17x17 input // 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 7;
constexpr index_t X = 7; constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<0, 3>; using RightPads = Sequence<3, 0>;
#endif #endif
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{}); auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment