Commit bd5a1bc2 authored by Chao Liu's avatar Chao Liu
Browse files

add bwd-data-v4r1 nhwc

parent e9c5efc4
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// Number of GEMMs = YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = YDotSlice * XDotSlice * K
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t ThreadGemmDataPerRead_GemmM,
index_t ThreadGemmDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmK,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nhwc_kyxc_nhwk
{
__host__ __device__ static constexpr index_t GetNumberOfGemm()
{
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
return YTilda * XTilda;
}
__host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
{
constexpr index_t N = InGlobalDesc::GetLengths()[0];
constexpr index_t Hi = InGlobalDesc::GetLengths()[1];
constexpr index_t Wi = InGlobalDesc::GetLengths()[2];
constexpr index_t C = InGlobalDesc::GetLengths()[3];
constexpr index_t Ho = OutGlobalDesc::GetLengths()[1];
constexpr index_t Wo = OutGlobalDesc::GetLengths()[2];
constexpr index_t K = OutGlobalDesc::GetLengths()[3];
constexpr index_t Y = WeiGlobalDesc::GetLengths()[1];
constexpr index_t X = WeiGlobalDesc::GetLengths()[2];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr index_t iHTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t iWTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t iHTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t iWTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;
// GemmM and GemmN
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
// GemmK is different for each GEMM
index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
index_t GemmK = YDotSlice * XDotSlice * K;
return Array<index_t, 3>{GemmM, GemmN, GemmK};
}
__host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
{
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
index_t iYTilda = gemm_id / XTilda;
index_t iXTilda = gemm_id % XTilda;
return GetGemmSizeImpl(iYTilda, iXTilda);
}
template <index_t iYTilda, index_t iXTilda>
__device__ static void RunImpl(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global)
{
constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{};
constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{};
constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[0];
constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[1];
constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[2];
constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[3];
constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[1];
constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[2];
constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[1];
constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[2];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr index_t iHTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t iWTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t iHTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t iWTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;
// A matrix: weight
// weight out-of-bound check can be skipped
constexpr bool wei_skip_out_of_bound_check = true;
constexpr auto wei_k_ydot_ytilda_xdot_xtilda_c_global_desc = transform_tensor_descriptor(
wei_k_y_x_c_global_desc,
make_tuple(PassThrough<K>{},
Embed<Y,
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
wei_skip_out_of_bound_check>{},
Embed<X,
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
wei_skip_out_of_bound_check>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
constexpr auto wei_k_ydotslice_xdotslice_c_global_desc = transform_tensor_descriptor(
wei_k_ydot_ytilda_xdot_xtilda_c_global_desc,
make_tuple(
PassThrough<K>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<>{}, Sequence<3>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_ydotslice_xdotslice_c_global_desc,
make_tuple(Merge<Sequence<YDotSlice, XDotSlice, K>>{}, PassThrough<C>{}),
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix: output tensor
// TODO sometimes output tensor out-of-bound check can be skipped, find out all such
// situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool out_skip_out_of_bound_check = false;
#else
constexpr bool out_skip_out_of_bound_check = true;
#endif
constexpr auto out_n_ydot_htilda_xdot_wtilda_k_global_desc = transform_tensor_descriptor(
out_n_ho_wo_k_global_desc,
make_tuple(PassThrough<N>{},
Embed<Ho,
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
out_skip_out_of_bound_check>{},
Embed<Wo,
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
out_skip_out_of_bound_check>{},
PassThrough<K>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
constexpr auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc =
transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_global_desc,
make_tuple(
PassThrough<N>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{},
PassThrough<K>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc,
make_tuple(Merge<Sequence<YDotSlice, XDotSlice, K>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix: input tensor
// TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool in_skip_out_of_bound_check = false;
#else
constexpr bool in_skip_out_of_bound_check = true;
#endif
constexpr auto in_n_hip_wip_c_global_desc = transform_tensor_descriptor(
in_n_hi_wi_c_global_desc,
make_tuple(PassThrough<N>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[1];
constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[2];
constexpr auto in_n_ytilda_htilda_xtilda_wtilda_c_global_desc = transform_tensor_descriptor(
in_n_hip_wip_c_global_desc,
make_tuple(PassThrough<N>{},
Embed<Hip,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_out_of_bound_check>{},
Embed<Wip,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_out_of_bound_check>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
constexpr auto in_n_htildaslice_wtildaslice_c_global_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_global_desc,
make_tuple(PassThrough<N>{},
Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_global_desc,
make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// call GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
ThreadGemmDataPerRead_GemmM,
ThreadGemmDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
0,
GemmBBlockCopySrcDataPerRead_GemmK,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
template <index_t GemmId>
__device__ static void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global,
Number<GemmId>)
{
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t iYTilda = GemmId / XTilda;
constexpr index_t iXTilda = GemmId % XTilda;
static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYtilda, iXtilda");
RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
}
};
} // namespace ck
#endif
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1_nhwc_kyxc_nhwk(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
std::size_t nrepeat)
{
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr auto in_nhwc_desc = make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
constexpr auto wei_kyxc_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
constexpr auto out_nhwk_desc = make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});
Tensor<float> in_nhwc(make_HostTensorDescriptor(in_nhwc_desc));
Tensor<float> wei_kyxc(make_HostTensorDescriptor(wei_kyxc_desc));
Tensor<float> out_nhwk(make_HostTensorDescriptor(out_nhwk_desc));
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
};
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
};
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
};
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
std::size_t data_sz = sizeof(T);
DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
#if 0
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t i = 0; i < nrepeat; ++i)
{
using GridwiseConvBwdData =
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nhwc_kyxc_nhwk<
GridSize,
BlockSize,
T,
T,
decltype(in_nhwc_desc),
decltype(wei_kyxc_desc),
decltype(out_nhwk_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmK,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>;
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
constexpr index_t gemm_k = gemm_sizes.At(2);
constexpr bool is_gemm_not_empty = gemm_k > 0;
// only compile and run if GEMM is no empty
static_if<is_gemm_not_empty>{}([&](auto fwd) {
launch_kernel(run_gridwise_operation<GridwiseConvBwdData,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__,
decltype(gemm_id)>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nhwc_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kyxc_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nhwk_device_buf.GetDeviceBuffer()),
fwd(gemm_id));
});
});
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
in_nhwc_device_buf.FromDevice(in_nhwc.mData.data());
auto f_nhwc2nchw = [&](auto n, auto c, auto hi, auto wi) {
in_nchw(n, c, hi, wi) = in_nhwc(n, hi, wi, c);
};
make_ParallelTensorFunctor(f_nhwc2nchw, N, C, Hi, Wi)(std::thread::hardware_concurrency());
}
} // namespace launcher
......@@ -16,6 +16,7 @@
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp"
int main(int argc, char* argv[])
......@@ -156,7 +157,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 0
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 256;
......@@ -186,7 +187,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 256;
......@@ -250,6 +251,8 @@ int main(int argc, char* argv[])
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#elif 1
device_convolution_backward_data_implicit_gemm_v4r1_nhwc_kyxc_nhwk
#elif 1
device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment