"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "fce888f4c0c3eeca9f8d5c0f32b3519babcaa17b"
Commit b7992190 authored by Chao Liu

adding bwd data v2r1

v2r1 computes backward data as one implicit GEMM, in = wei^T * out: the
weight tensor is the transposed A matrix, the output tensor is B, and the
input tensor is written as the C matrix. Stride and dilation are fixed to 1
for now.

parent cfff66cd
@@ -19,8 +19,8 @@ template <index_t GridSize,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock,
index_t EPerBlock,
index_t BPerBlock,
index_t KPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
@@ -29,14 +29,14 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename OutBlockCopySubLengths_K_B,
typename OutBlockCopyClusterLengths_K_B,
index_t OutBlockCopyDataPerAccess_B,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename WeiBlockCopySubLengths_K_E,
typename WeiBlockCopyClusterLengths_K_E,
index_t WeiBlockCopyDataPerAccess_E,
typename OutBlockCopySubLengths_K_B,
typename OutBlockCopyClusterLengths_K_B,
index_t OutBlockCopyDataPerAccess_B,
index_t InThreadCopyDataPerAccess_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
@@ -139,8 +139,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
WeiBlockCopySubLengths_K_E,
WeiBlockCopyClusterLengths_K_E,
WeiBlockCopyDataPerAccess_E,
......
@@ -22,8 +22,8 @@ template <index_t GridSize,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock,
index_t EPerBlock,
index_t BPerBlock,
index_t KPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
......
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopySubLengths, // Gemm-K, Gemm-M
typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M
index_t GemmABlockCopyDataPerAccess, // Gemm-M
typename GemmBBlockCopySubLengths, // Gemm-K, Gemm-N
typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N
index_t GemmBBlockCopyDataPerAccess, // Gemm-N
index_t GemmCThreadCopyDataPerAccess // Gemm-N
>
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// TODO: this algorithm supports arbitrary stride and dilation, but for now
// fix them to 1 for simplicity
static_assert(ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1,
"wrong! not supported yet");
// TODO: the logic below is only valid for stride = 1, dilation = 1
constexpr index_t Ydot = Y;
constexpr index_t Ytilda = 1;
constexpr index_t Htilda = Ho + Y - 1;
constexpr index_t Xdot = X;
constexpr index_t Xtilda = 1;
constexpr index_t Wtilda = Wo + X - 1;
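// GEMM view of backward data: GemmK reduces over the output channels and all
// filter taps (K * Ydot * Xdot), GemmM covers the input channels
// (C * Ytilda * Xtilda, = C here), and GemmN covers every point of the
// padded Htilda x Wtilda input image across the batch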
constexpr index_t GemmK = K * Ydot * Xdot;
constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(
PassThrough<K>{},
PassThrough<C>{},
Embed<Sequence<Ydot, Ytilda>, Sequence<1, 1, 0>>{}, // coefficient may be wrong
Embed<Sequence<Xdot, Xtilda>, Sequence<1, 1, 0>>{}), // coefficient may be wrong
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<C, Ytilda, Xtilda>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
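// pad (Ho, Wo) out to (Htilda, Wtilda), embed into (Ydot, Htilda) and
// (Xdot, Wtilda), then merge into the B matrix out[GemmK][GemmN]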
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<Y - 1, X - 1>>{}), // coefficient may
// be wrong
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_hop_wop_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>, Sequence<0, 1, 0>>{}, // coefficient may be wrong
Embed<Sequence<Xdot, Wtilda>, Sequence<0, 1, 0>>{}), // coefficient may be wrong
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor
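// the input is the C matrix of the GEMM: pad by the effective pads (user
// padding plus the Y-1 / X-1 halo), embed, and merge into in[GemmM][GemmN]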
constexpr auto eff_left_pads = LeftPads{} + Sequence<Y - 1, X - 1>{};
constexpr auto eff_right_pads = RightPads{} + Sequence<Y - 1, X - 1>{};
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, decltype(eff_left_pads), decltype(eff_right_pads)>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
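// A = transposed weight (GemmK x GemmM), B = output (GemmK x GemmN),
// C = input (GemmM x GemmN), so the GEMM computes in = wei^T * out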
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths,
GemmABlockCopyClusterLengths,
GemmABlockCopyDataPerAccess,
GemmBBlockCopySubLengths,
GemmBBlockCopyClusterLengths,
GemmBBlockCopyDataPerAccess,
GemmCThreadCopyDataPerAccess>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
};
} // namespace ck
#endif
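Below is a minimal standalone host sketch (not part of the commit) of how the
v2r1 GEMM shape and grid size fall out of the formulas above; the tensor sizes
are hypothetical stand-ins, and it follows the kernel's stride-1 split
Ydot = Y, Xdot = X, Ytilda = Xtilda = 1:

// sketch only: sizes below are made up for illustration
#include <cstdio>

int main()
{
    constexpr int N = 64, C = 256, K = 512;                  // hypothetical sizes
    constexpr int Hi = 34, Wi = 34, Y = 3, X = 3;            // 34x34 image, 3x3 filter
    constexpr int Ho = Hi - Y + 1, Wo = Wi - X + 1;          // 32x32, no padding
    constexpr int Htilda = Ho + Y - 1, Wtilda = Wo + X - 1;  // back to 34x34

    constexpr int GemmK = K * Y * X; // Ydot = Y, Xdot = X
    constexpr int GemmM = C;         // Ytilda = Xtilda = 1
    constexpr int GemmN = N * Htilda * Wtilda;

    constexpr int GemmMPerBlock = 128, GemmNPerBlock = 128;
    constexpr int GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
                             ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);

    printf("GemmK %d, GemmM %d, GemmN %d, GridSize %d\n",
           GemmK, GemmM, GemmN, GridSize);
    return 0;
}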
@@ -49,38 +49,37 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
// BlockSize = 256; each thread holds 64 elements of C
constexpr index_t BlockSize = 256;
constexpr index_t EPerBlock = 128;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using OutBlockCopySubLengths_K_B = Sequence<4, 1>;
using OutBlockCopyClusterLengths_K_B = Sequence<2, 128>;
constexpr index_t OutBlockCopyDataPerAccess_B = 1;
using WeiBlockCopySubLengths_K_E = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_K_E = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_E = 4;
constexpr index_t InThreadCopyDataPerAccess_B = 1;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
#endif
constexpr index_t E = C * Y * X;
constexpr index_t B = (N * Ho * Wo);
constexpr index_t GemmM = C * Y * X;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize =
((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -96,9 +95,9 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
ConvDilations,
LeftPads,
RightPads,
BPerBlock,
EPerBlock,
KPerBlock,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
@@ -106,15 +105,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
OutBlockCopySubLengths_K_B,
OutBlockCopyClusterLengths_K_B,
OutBlockCopyDataPerAccess_B,
WeiBlockCopySubLengths_K_E,
WeiBlockCopyClusterLengths_K_E,
WeiBlockCopyDataPerAccess_E,
InThreadCopyDataPerAccess_B>{};
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths,
GemmABlockCopyClusterLengths,
GemmABlockCopyDataPerAccess,
GemmBBlockCopySubLengths,
GemmBBlockCopyClusterLengths,
GemmBBlockCopyDataPerAccess,
GemmCThreadCopyDataPerAccess>{};
for(index_t i = 0; i < nrepeat; ++i)
{
......
@@ -105,8 +105,8 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
ConvDilations,
LeftPads,
RightPads,
BPerBlock,
EPerBlock,
BPerBlock,
KPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
......
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
template <typename T,
typename InDesc,
typename WeiDesc,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads>
void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
std::size_t nrepeat)
{
using namespace ck;
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256; each thread holds 64 elements of C
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M
constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
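// tuning-parameter sanity: the 128x128 block tile is worked on by a
// (4x4) level-0 x (4x4) level-1 thread cluster = 256 threads, each holding
// an 8x8 = 64-element sub-tile of C (two 4-wide sub-chunks per dimension)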
#elif 0
// BlockSize = 256; each thread holds 64 elements of C
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
#endif
// TODO: this algorithm supports arbitrary stride and dilation, but for now
// fix them to 1 for simplicity
constexpr index_t Ydot = 1;
constexpr index_t Ytilda = Y;
constexpr index_t Htilda = Ho + Y - 1;
constexpr index_t Xdot = 1;
constexpr index_t Xtilda = X;
constexpr index_t Wtilda = Wo + X - 1;
constexpr index_t GemmK = K * Ydot * Xdot;
constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda;
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths,
GemmABlockCopyClusterLengths,
GemmABlockCopyDataPerAccess,
GemmBBlockCopySubLengths,
GemmBBlockCopyClusterLengths,
GemmBBlockCopyDataPerAccess,
GemmCThreadCopyDataPerAccess>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
gridwise_conv,
const_cast<T* const __restrict__>(
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
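// (assumption) run_gridwise_operation forwards its arguments to
// gridwise_conv.Run(in, wei, out) on the device; the const_cast/static_cast
// chain adapts the void* device buffers to Run's restrict-qualified pointers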
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
@@ -15,6 +15,7 @@
#include "host_conv_bwd_data.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[])
{
@@ -34,7 +35,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
@@ -49,7 +50,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64;
@@ -337,18 +338,23 @@ int main(int argc, char* argv[])
if(do_verification)
{
#if 0
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#elif 0
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#else
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#endif
}
#if 1
#if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#else
#elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#else
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#endif
(in_nchw_desc,
in_nchw_device,
......