Commit ef2664fb authored by Chao Liu

adding bwd data v3r1

parent 36dfb2ad
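// ---- new file: gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
// (name taken from the include guard here and from the #include in the device file below) ----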
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// number of GEMMs = Ytilda * Xtilda; for each GEMM:
// GemmM = C
// GemmN = N * Htilda * Wtilda
// GemmK = K * slice(Ydot) * slice(Xdot)
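//
// A worked example (illustrative numbers, matching the 3x3 filter, stride 2x2,
// dilation 1x1 case enabled in the driver below): hcf(2, 1) = 1, so
// Ytilda = Xtilda = 2 / 1 = 2 and the kernel runs 2 * 2 = 4 GEMMs; each GEMM
// consumes the filter taps whose (y % Ytilda, x % Xtilda) residues match its
// (ytilda, xtilda) pair, at most Ydot = Xdot = ceil(3 / 2) = 2 taps per
// residue class.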
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
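// Worked example for the tilde-window sizes above (hypothetical shape:
// Hi = 16, pad 0, Y = 3, stride 2, dilation 1, hence Ho = 7):
// Htilda = 7 + ceil(1 * 2 / 2) = 8, HtildaLeft = floor(max(0, 0 - 1) / 2) = 0,
// HtildaRight = min(8, ceil(15 / 2) + 1) = 8, so HtildaTrim = 8; with nonzero
// padding the left/right clamps shrink the window instead of leaving it whole.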
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
false>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
false>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
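// As I read the Embed coefficients above (noting that
// ConvStrideH / hcf_stride_dilation_h == Ytilda), the map back into the
// original filter coordinate is y = ydot * Ytilda + ytilda, i.e.
// ytilda = y % Ytilda and ydot = y / Ytilda: the filter taps are regrouped
// by their residue modulo Ytilda (likewise for x with Xtilda).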
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
false>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
false>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
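// Here the Embed coefficient for ydot is negative:
// ho = htilda - ydot * (ConvDilationH / hcf_stride_dilation_h), so for a
// fixed htilda, a larger filter tap (larger ydot) reads an earlier output
// row -- which is how one input pixel gathers every tap that contributed
// to it in the forward pass.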
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
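// The input-side Embed uses plain dilation/stride coefficients:
// hip = ytilda * ConvDilationH + htilda * ConvStrideH (within the padded
// input), i.e. the usual forward-convolution index map restricted to the
// ytilda residue class.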
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
#if 1 // debug
// GEMMs
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
#else
static_for<1, 2, 1>{}([&](auto ytilda_) {
static_for<1, 2, 1>{}([&](auto xtilda_) {
#endif
constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){};
constexpr index_t Ydotnonzero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t Xdotnonzero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
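// Example of the slice-size logic above (Y = 3, Ytilda = 2, Ydot = 2):
// ytilda = 0 keeps both taps ((0 + 1) * 2 <= 3, so Ydotnonzero = 2), while
// ytilda = 1 keeps only Y % Ydot = 1 tap; trimming away the all-zero tail
// of Ydot/Xdot keeps GemmK as small as possible for each of the
// Ytilda * Xtilda GEMMs.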
// A matrix
constexpr auto wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Trim<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<Ydot - Ydotnonzero, Xdot - Xdotnonzero>>{},
Trim<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc,
make_tuple(Merge<Sequence<K, Ydotnonzero, Xdotnonzero>>{},
Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix
constexpr auto out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Trim<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<Ydot - Ydotnonzero, Xdot - Xdotnonzero>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydotnonzero_htildatrim_xdotnonzero_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, Ydotnonzero, Xdotnonzero>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Trim<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
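// At this point the three descriptors realize the decomposition stated at
// the top of the file: GemmM = C, GemmN = N * HtildaTrim * WtildaTrim (the
// H/W window trimmed to its nonempty range), and
// GemmK = K * Ydotnonzero * Xdotnonzero shared by the A and B matrices.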
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
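// Note: InMemoryDataOperation::none (plain store) appears safe here
// because, as I read the descriptors, the C-matrix trim for each
// (ytilda, xtilda) pair selects a disjoint slice of the input tensor, so
// the Ytilda * Xtilda GEMMs never write the same pixel and no atomic
// accumulation is needed.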
});
});
}
};
} // namespace ck
#endif
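// ---- new file: device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
// (name taken from the #include added in the driver diff below) ----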
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
template <typename T,
typename InDesc,
typename WeiDesc,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
std::size_t nrepeat)
{
using namespace ck;
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256; each thread computes 64 elements of C (GemmMPerBlock * GemmNPerBlock / BlockSize = 128 * 128 / 256)
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
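// Sanity checks on this configuration (my arithmetic, not asserted in code):
// the A/B copy tile is SliceLengths * ClusterLengths = <4, 1> * <2, 128> =
// <8, 128> = <GemmKPerBlock, GemmMPerBlock>, using 2 * 128 = 256 = BlockSize
// threads; the math cluster also uses 4 * 4 * 4 * 4 = 256 = BlockSize
// threads, each owning 2 x 2 repeats of 4x4 sub-tiles = 64 C elements.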
#elif 1
// BlockSize = 256; each thread computes 64 elements of C (GemmMPerBlock * GemmNPerBlock / BlockSize = 128 * 128 / 256)
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
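// Illustrative sizing (hypothetical shapes): with GemmM = C = 256 and
// GemmN = N * HtildaTrim * WtildaTrim = 128 * 8 * 8 = 8192, GridSize =
// ceil(256 / 128) * ceil(8192 / 128) = 2 * 64 = 128 workgroups.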
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
gridwise_conv,
const_cast<T* const __restrict__>(
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
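---- modified file: the convolution driver (path not shown in this view; hunks below) ----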
@@ -16,6 +16,7 @@
 #include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
 int main(int argc, char* argv[])
 {
@@ -32,7 +33,7 @@ int main(int argc, char* argv[])
 constexpr index_t X = 3;
 using ConvStrides = Sequence<2, 2>;
-using ConvDilations = Sequence<3, 3>;
+using ConvDilations = Sequence<1, 1>;
 using LeftPads = Sequence<0, 0>;
 using RightPads = Sequence<0, 0>;
@@ -171,7 +172,7 @@ int main(int argc, char* argv[])
 using LeftPads = Sequence<0, 3>;
 using RightPads = Sequence<0, 3>;
-#elif 1
+#elif 0
 // 7x1 filter, 3x0 pad, 17x17 input
 constexpr index_t N = 128;
 constexpr index_t C = 128;
@@ -248,8 +249,10 @@ int main(int argc, char* argv[])
 device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
 #elif 0
 device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
-#else
+#elif 1
 device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
+#else
+device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
 #endif
 (in_nchw_desc,
 in_nchw_device,