Commit b7992190 authored by Chao Liu

adding bwd data v2r1

parent cfff66cd
@@ -19,8 +19,8 @@ template <index_t GridSize,
           typename ConvDilations,
           typename LeftPads,
           typename RightPads,
-          index_t BPerBlock,
           index_t EPerBlock,
+          index_t BPerBlock,
           index_t KPerBlock,
           index_t GemmMPerThreadSubC,
           index_t GemmNPerThreadSubC,
@@ -29,14 +29,14 @@ template <index_t GridSize,
           index_t GemmMLevel1Cluster,
           index_t GemmNLevel1Cluster,
           index_t GemmKPerThreadLoop,
-          index_t GemmDataPerReadA,
-          index_t GemmDataPerReadB,
-          typename OutBlockCopySubLengths_K_B,
-          typename OutBlockCopyClusterLengths_K_B,
-          index_t OutBlockCopyDataPerAccess_B,
+          index_t GemmThreadGemmDataPerReadM,
+          index_t GemmThreadGemmDataPerReadN,
           typename WeiBlockCopySubLengths_K_E,
           typename WeiBlockCopyClusterLengths_K_E,
           index_t WeiBlockCopyDataPerAccess_E,
+          typename OutBlockCopySubLengths_K_B,
+          typename OutBlockCopyClusterLengths_K_B,
+          index_t OutBlockCopyDataPerAccess_B,
           index_t InThreadCopyDataPerAccess_B>
 struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
 {
@@ -139,8 +139,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
             GemmKPerThreadLoop,
-            GemmDataPerReadA,
-            GemmDataPerReadB,
+            GemmThreadGemmDataPerReadM,
+            GemmThreadGemmDataPerReadN,
             WeiBlockCopySubLengths_K_E,
             WeiBlockCopyClusterLengths_K_E,
             WeiBlockCopyDataPerAccess_E,
...
@@ -22,8 +22,8 @@ template <index_t GridSize,
           typename ConvDilations,
           typename LeftPads,
           typename RightPads,
-          index_t BPerBlock,
           index_t EPerBlock,
+          index_t BPerBlock,
           index_t KPerBlock,
           index_t GemmMPerThreadSubC,
           index_t GemmNPerThreadSubC,
...
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {
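
// GridwiseConvolutionBackwardDataImplicitGemm_v2r1: computes the input gradient (backward data)
// of an NCHW-input / KCYX-weight / NKHW-output convolution as a single implicit GEMM, by
// reinterpreting the weight, output-gradient, and input-gradient tensors as 2D GEMM matrices
// through tensor-descriptor transforms.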
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          typename GemmABlockCopySubLengths,     // Gemm-K, Gemm-M
          typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M
          index_t GemmABlockCopyDataPerAccess,   // Gemm-M
          typename GemmBBlockCopySubLengths,     // Gemm-K, Gemm-N
          typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N
          index_t GemmBBlockCopyDataPerAccess,   // Gemm-N
          index_t GemmCThreadCopyDataPerAccess   // Gemm-N
          >
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
    __device__ void Run(Float* __restrict__ p_in_global,
                        const Float* __restrict__ p_wei_global,
                        const Float* __restrict__ p_out_global) const
    {
        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];
        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) &&
                          (X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor "
                      "will be violated");

        // TODO: this algorithm supports any stride and dilation, but for now fix them to 1 for
        // simplicity
        static_assert(ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
                          ConvDilationW == 1,
                      "wrong! not supported yet");
        // TODO: this logic only holds for stride = 1, dilation = 1
        constexpr index_t Ydot   = Y;
        constexpr index_t Ytilda = 1;
        constexpr index_t Htilda = Ho + Y - 1;

        constexpr index_t Xdot   = X;
        constexpr index_t Xtilda = 1;
        constexpr index_t Wtilda = Wo + X - 1;

        constexpr index_t GemmK = K * Ydot * Xdot;
        constexpr index_t GemmM = C * Ytilda * Xtilda;
        constexpr index_t GemmN = N * Htilda * Wtilda;
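
        // For intuition (a hypothetical 3x3 filter on a 32x32 output, stride = dilation = 1):
        // Htilda = Ho + Y - 1 = 32 + 3 - 1 = 34 is the height of the "full" correlation between
        // the output gradient and the filter, so GemmN = N * Htilda * Wtilda covers every input
        // pixel touched by the backward-data pass, while GemmK = K * Ydot * Xdot = K * Y * X
        // folds the reduction over output channels and filter taps into the GEMM K dimension.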
        // weight tensor
        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(
                PassThrough<K>{},
                PassThrough<C>{},
                Embed<Sequence<Ydot, Ytilda>, Sequence<1, 1, 0>>{},  // coefficient may be wrong
                Embed<Sequence<Xdot, Xtilda>, Sequence<1, 1, 0>>{}), // coefficient may be wrong
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
            wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
            make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<C, Ytilda, Xtilda>>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
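
        // Note on the Embed transform (my reading of tensor_descriptor.hpp; the in-source
        // "coefficient may be wrong" remarks suggest these numbers were still being tuned):
        // Embed<Sequence<L0, L1>, Sequence<c0, c1, c2>> lifts one lower dimension into two upper
        // dimensions of lengths L0 and L1, mapping upper index (i0, i1) to lower index
        // c0 * i0 + c1 * i1 + c2. Under that reading, Sequence<1, 1, 0> above splits the filter
        // dimension into (ydot, ytilda) with y = ydot + ytilda.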
        // output tensor
        constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(
                PassThrough<N>{},
                PassThrough<K>{},
                Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<Y - 1, X - 1>>{}), // coefficient
                                                                                  // may be wrong
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
            out_n_k_hop_wop_global_desc,
            make_tuple(
                PassThrough<N>{},
                PassThrough<K>{},
                Embed<Sequence<Ydot, Htilda>, Sequence<0, 1, 0>>{},  // coefficient may be wrong
                Embed<Sequence<Xdot, Wtilda>, Sequence<0, 1, 0>>{}), // coefficient may be wrong
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ydot_htilda_xdot_wtilda_global_desc,
            make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        // input tensor
        constexpr auto eff_left_pads  = LeftPads{} + Sequence<Y - 1, X - 1>{};
        constexpr auto eff_right_pads = RightPads{} + Sequence<Y - 1, X - 1>{};

        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{},
                PassThrough<C>{},
                Pad<Sequence<Hi, Wi>, decltype(eff_left_pads), decltype(eff_right_pads)>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(
                PassThrough<N>{},
                PassThrough<C>{},
                Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
            make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
            make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
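
        // The three descriptors above set up the backward-data pass as one implicit GEMM,
        // C[GemmM, GemmN] = sum over GemmK of A[GemmK, GemmM] * B[GemmK, GemmN], matching the
        // "TransposedANormalBNormalC" kernel below: A = weight (stored GemmK-major, hence
        // "transposed"), B = output gradient, C = input gradient.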
        // GEMM
        constexpr auto gridwise_gemm =
            GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
                                                       BlockSize,
                                                       Float,
                                                       AccFloat,
                                                       decltype(wei_gemmk_gemmm_global_desc),
                                                       decltype(out_gemmk_gemmn_global_desc),
                                                       decltype(in_gemmm_gemmn_global_desc),
                                                       InMemoryDataOperation::none,
                                                       GemmMPerBlock,
                                                       GemmNPerBlock,
                                                       GemmKPerBlock,
                                                       GemmMPerThreadSubC,
                                                       GemmNPerThreadSubC,
                                                       GemmMLevel0Cluster,
                                                       GemmNLevel0Cluster,
                                                       GemmMLevel1Cluster,
                                                       GemmNLevel1Cluster,
                                                       GemmKPerThreadLoop,
                                                       GemmThreadGemmDataPerReadM,
                                                       GemmThreadGemmDataPerReadN,
                                                       GemmABlockCopySubLengths,
                                                       GemmABlockCopyClusterLengths,
                                                       GemmABlockCopyDataPerAccess,
                                                       GemmBBlockCopySubLengths,
                                                       GemmBBlockCopyClusterLengths,
                                                       GemmBBlockCopyDataPerAccess,
                                                       GemmCThreadCopyDataPerAccess>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }
};
} // namespace ck
#endif
@@ -49,10 +49,9 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
     // BlockSize = 256, each thread holds 64 data
     constexpr index_t BlockSize = 256;
-    constexpr index_t EPerBlock = 128;
-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 8;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;
     constexpr index_t GemmMPerThreadSubC = 4;
     constexpr index_t GemmNPerThreadSubC = 4;
     constexpr index_t GemmMLevel0Cluster = 4;
@@ -60,27 +59,27 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
     constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA = 4;
-    constexpr index_t GemmDataPerReadB = 4;
-    using OutBlockCopySubLengths_K_B     = Sequence<4, 1>;
-    using OutBlockCopyClusterLengths_K_B = Sequence<2, 128>;
-    constexpr index_t OutBlockCopyDataPerAccess_B = 1;
-    using WeiBlockCopySubLengths_K_E     = Sequence<1, 4>;
-    using WeiBlockCopyClusterLengths_K_E = Sequence<8, 32>;
-    constexpr index_t WeiBlockCopyDataPerAccess_E = 4;
-    constexpr index_t InThreadCopyDataPerAccess_B = 1;
+    constexpr index_t GemmThreadGemmDataPerReadM = 4;
+    constexpr index_t GemmThreadGemmDataPerReadN = 4;
+    using GemmABlockCopySubLengths     = Sequence<1, 4>;   // Gemm-K, Gemm-M
+    using GemmABlockCopyClusterLengths = Sequence<8, 32>;  // Gemm-K, Gemm-M
+    constexpr index_t GemmABlockCopyDataPerAccess = 4;     // Gemm-M
+    using GemmBBlockCopySubLengths     = Sequence<4, 1>;   // Gemm-K, Gemm-N
+    using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
+    constexpr index_t GemmBBlockCopyDataPerAccess = 1;     // Gemm-N
+    constexpr index_t GemmCThreadCopyDataPerAccess = 1;    // Gemm-N
 #endif
-    constexpr index_t E = C * Y * X;
-    constexpr index_t B = (N * Ho * Wo);
-    constexpr index_t GridSize =
-        ((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);
+    constexpr index_t GemmM = C * Y * X;
+    constexpr index_t GemmN = N * Ho * Wo;
+    constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
+                                 ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -96,9 +95,9 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
         ConvDilations,
         LeftPads,
         RightPads,
-        BPerBlock,
-        EPerBlock,
-        KPerBlock,
+        GemmMPerBlock,
+        GemmNPerBlock,
+        GemmKPerBlock,
         GemmMPerThreadSubC,
         GemmNPerThreadSubC,
         GemmMLevel0Cluster,
@@ -106,15 +105,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
         GemmKPerThreadLoop,
-        GemmDataPerReadA,
-        GemmDataPerReadB,
-        OutBlockCopySubLengths_K_B,
-        OutBlockCopyClusterLengths_K_B,
-        OutBlockCopyDataPerAccess_B,
-        WeiBlockCopySubLengths_K_E,
-        WeiBlockCopyClusterLengths_K_E,
-        WeiBlockCopyDataPerAccess_E,
-        InThreadCopyDataPerAccess_B>{};
+        GemmThreadGemmDataPerReadM,
+        GemmThreadGemmDataPerReadN,
+        GemmABlockCopySubLengths,
+        GemmABlockCopyClusterLengths,
+        GemmABlockCopyDataPerAccess,
+        GemmBBlockCopySubLengths,
+        GemmBBlockCopyClusterLengths,
+        GemmBBlockCopyDataPerAccess,
+        GemmCThreadCopyDataPerAccess>{};
     for(index_t i = 0; i < nrepeat; ++i)
     {
...
@@ -105,8 +105,8 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
         ConvDilations,
         LeftPads,
         RightPads,
-        BPerBlock,
         EPerBlock,
+        BPerBlock,
         KPerBlock,
         GemmMPerThreadSubC,
         GemmNPerThreadSubC,
...
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads>
void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
                                                                        Tensor<T>& in_nchw,
                                                                        WeiDesc wei_kcyx_desc,
                                                                        const Tensor<T>& wei_kcyx,
                                                                        OutDesc out_nkhw_desc,
                                                                        const Tensor<T>& out_nkhw,
                                                                        ConvStrides,
                                                                        ConvDilations,
                                                                        LeftPads,
                                                                        RightPads,
                                                                        std::size_t nrepeat)
{
    using namespace ck;

    constexpr index_t N  = out_nkhw_desc.GetLengths()[0];
    constexpr index_t K  = out_nkhw_desc.GetLengths()[1];
    constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
    constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];

    constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
    constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
    constexpr index_t X = wei_kcyx_desc.GetLengths()[3];

    std::size_t data_sz = sizeof(T);

    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
    // BlockSize = 256, each thread holds 64 data
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;

    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopySubLengths     = Sequence<4, 1>;   // Gemm-K, Gemm-M
    using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M

    constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M

    using GemmBBlockCopySubLengths     = Sequence<4, 1>;   // Gemm-K, Gemm-N
    using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N

    constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N

    constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
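
    // Quick consistency check on these copy parameters (my arithmetic, not from the commit):
    // per-thread sub-lengths times thread-cluster lengths must tile the block copy exactly, and
    // the cluster must use every thread. Here Sequence<4, 1> * Sequence<2, 128> = <8, 128>
    // = <GemmKPerBlock, GemmNPerBlock>, and 2 * 128 = 256 = BlockSize. Likewise each thread owns
    // (128 * 128) / 256 = 64 accumulators, matching the "each thread holds 64 data" comment.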
#elif 0
    // BlockSize = 256, each thread holds 64 data
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;

    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopySubLengths     = Sequence<1, 4>;  // Gemm-K, Gemm-M
    using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M

    constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M

    using GemmBBlockCopySubLengths     = Sequence<4, 1>;   // Gemm-K, Gemm-N
    using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N

    constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N

    constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
#endif
    // TODO: this algorithm supports any stride and dilation, but for now fix them to 1 for
    // simplicity
    constexpr index_t Ydot   = 1;
    constexpr index_t Ytilda = Y;
    constexpr index_t Htilda = Ho + Y - 1;

    constexpr index_t Xdot   = 1;
    constexpr index_t Xtilda = X;
    constexpr index_t Wtilda = Wo + X - 1;

    constexpr index_t GemmK = K * Ydot * Xdot;
    constexpr index_t GemmM = C * Ytilda * Xtilda;
    constexpr index_t GemmN = N * Htilda * Wtilda;

    constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
                                 ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
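
    // Worked example (my numbers, assuming the 3x3, 34x34 config enabled in the driver gives
    // N = 64, C = 256, Y = X = 3, Ho = Wo = 32): GemmM = 256 * 3 * 3 = 2304, Htilda = Wtilda = 34,
    // so GemmN = 64 * 34 * 34 = 73984, and GridSize = (2304 / 128) * (73984 / 128)
    // = 18 * 578 = 10404 blocks, with both divisions exact.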
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
    constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw<
        GridSize,
        BlockSize,
        T,
        T,
        decltype(in_nchw_desc),
        decltype(wei_kcyx_desc),
        decltype(out_nkhw_desc),
        ConvStrides,
        ConvDilations,
        LeftPads,
        RightPads,
        GemmMPerBlock,
        GemmNPerBlock,
        GemmKPerBlock,
        GemmMPerThreadSubC,
        GemmNPerThreadSubC,
        GemmMLevel0Cluster,
        GemmNLevel0Cluster,
        GemmMLevel1Cluster,
        GemmNLevel1Cluster,
        GemmKPerThreadLoop,
        GemmThreadGemmDataPerReadM,
        GemmThreadGemmDataPerReadN,
        GemmABlockCopySubLengths,
        GemmABlockCopyClusterLengths,
        GemmABlockCopyDataPerAccess,
        GemmBBlockCopySubLengths,
        GemmBBlockCopyClusterLengths,
        GemmBBlockCopyDataPerAccess,
        GemmCThreadCopyDataPerAccess>{};
    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
                                                          T* const __restrict__,
                                                          const T* const __restrict__,
                                                          const T* const __restrict__>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   gridwise_conv,
                                   const_cast<T* const __restrict__>(
                                       static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
                                   const_cast<const T* const __restrict__>(
                                       static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
                                   const_cast<const T* const __restrict__>(
                                       static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));

        // time is in ms, so flops / 1e9 / ms = TFlop/s
        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);
        usleep(std::min(time * 1000, float(10000)));
    }
    in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
@@ -15,6 +15,7 @@
 #include "host_conv_bwd_data.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
 int main(int argc, char* argv[])
 {
@@ -34,7 +35,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;
@@ -49,7 +50,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 1x1 filter, 8x8 image
     // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
     constexpr index_t N = 64;
@@ -337,18 +338,23 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
 #if 0
+        out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
         wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
-        out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
+#elif 0
+        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+        out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
 #else
+        out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
         wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-        out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
 #endif
     }
-#if 1
+#if 0
     device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
-#else
+#elif 0
     device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
+#else
+    device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
 #endif
     (in_nchw_desc,
      in_nchw_device,
...