Commit e9c5efc4 authored by Chao Liu

add bwd-data-v5r1-nhwc, refactor bwd-data-v4r1-nchw, remove obsolete kernels

parent fe7b2d9f
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = C * YTilda * XTilda;
// GemmN = N * HTildaSlice * WTildaSlice;
// GemmK = K * YDot * XDot;
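// A/B/C mapping used below: A = weight (K-major, transposed), B = output, C = input,
// i.e. the GEMM computes in[m, n] = sum over k of wei[k, m] * out[k, n]
// (see the argument order of gridwise_gemm.Run at the end of Run())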
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
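// Worked example with hypothetical sizes Hi = 8, Ho = 4, Y = 3,
// ConvStrideH = 2, ConvDilationH = 1, InLeftPads{}[0] = 1:
//   GcdStrideDilationH = gcd(2, 1) = 1
//   YTilda = 2 / 1 = 2, YDot = ceil(3 / 2) = 2
//   HTilda = 4 + ceil(1 * (3 - 1) / 2) = 5
//   HTildaLeft = floor(max(0, 1 - 1 * (2 - 1)) / 2) = 0
//   HTildaRight = min(5, ceil((1 + 8 - 1) / 2) + 1) = 5
//   HTildaSlice = 5 - 0 = 5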
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>>{},
Embed<X,
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(Merge<Sequence<K, YDot, XDot>>{}, Merge<Sequence<C, YTilda, XTilda>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>>{},
Embed<Wo,
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDot, XDot>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
make_tuple(Merge<Sequence<C, YTilda, XTilda>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
{
// this is a hack; should query this info from gridwise_gemm instead of duplicating its logic
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned(
Sequence<GemmKPerBlock, GemmMPerBlock>{}, Number<max_lds_align>{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned(
Sequence<GemmKPerBlock, GemmNPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
}
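// Example with hypothetical tuning values (GemmKPerBlock = 8,
// GemmMPerBlock = GemmNPerBlock = 128, max_lds_align = 4, Float = float):
//   a_block_space = b_block_space = 8 * 128 = 1024 elements
//   (128 is already a multiple of 4, so no extra padding is added)
//   return value = 2 * (1024 + 1024) * 4 bytes = 16 KiB of LDS,
//   the factor of 2 presumably covering the double buffer in the gridwise GEMM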
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
constexpr bool wei_skip_all_out_of_bound_check = true;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#if 1 // debug
constexpr bool out_skip_all_out_of_bound_check = false;
#else
constexpr bool out_skip_all_out_of_bound_check = true;
#endif
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
// GEMMs
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
static_for<0, YTilda, 1>{}([&](auto iYTilda_) {
static_for<0, XTilda, 1>{}([&](auto iXTilda_) {
constexpr index_t iYTilda = decltype(iYTilda_){};
constexpr index_t iXTilda = decltype(iXTilda_){};
constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
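// e.g. Y = 3, YTilda = 2 gives YDot = 2: the GEMM at iYTilda = 0 sees the full
// YDotSlice = YDot = 2, while the GEMM at iYTilda = 1 sees the remainder
// YDotSlice = Y % YDot = 1 (and likewise for X)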
// A matrix
constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<YDot, XDot>,
Sequence<0, 0>,
Sequence<YDotSlice, XDotSlice>>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix
constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YDot, XDot>,
Sequence<0, 0>,
Sequence<YDotSlice, XDotSlice>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix
constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block);
// is synchronization necessary?
__syncthreads();
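// (it appears to be: every iteration of this double loop reuses the same
// p_shared_block, so without a barrier a thread could start writing LDS tiles
// for the next GEMM while other threads are still reading tiles of the current one)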
});
});
}
};
} // namespace ck
#endif
@@ -217,23 +217,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
-transform_tensor_descriptor(
-wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
-make_tuple(
-PassThrough<K>{},
-PassThrough<C>{},
-Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
-Slice<Sequence<YTilda, XTilda>,
-Sequence<iYTilda, iXTilda>,
-Sequence<iYTilda + 1, iXTilda + 1>>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
+constexpr auto wei_k_c_ydotslice_xdotslice_global_desc = transform_tensor_descriptor(
+wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
+make_tuple(
+PassThrough<K>{},
+PassThrough<C>{},
+Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
+Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{}),
+make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
+make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
-wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
-make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, Merge<Sequence<C, 1, 1>>{}),
-make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
+wei_k_c_ydotslice_xdotslice_global_desc,
+make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, PassThrough<C>{}),
+make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix: output tensor
@@ -265,8 +262,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
-PassThrough<YTilda>{},
-PassThrough<XTilda>{},
+PassThrough<YDot>{},
+PassThrough<XDot>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{}),
@@ -331,40 +328,21 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
-transform_tensor_descriptor(
-in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
-make_tuple(PassThrough<N>{},
-PassThrough<C>{},
-PassThrough<YTilda>{},
-PassThrough<XTilda>{},
-Slice<Sequence<HTilda, WTilda>,
-Sequence<iHTildaLeft, iWTildaLeft>,
-Sequence<iHTildaRight, iWTildaRight>>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
-constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
-transform_tensor_descriptor(
-in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
-make_tuple(PassThrough<N>{},
-PassThrough<C>{},
-PassThrough<HTildaSlice>{},
-PassThrough<WTildaSlice>{},
-Slice<Sequence<YTilda, XTilda>,
-Sequence<iYTilda, iXTilda>,
-Sequence<iYTilda + 1, iXTilda + 1>>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
+constexpr auto in_n_c_htildaslice_wtildaslice_global_desc = transform_tensor_descriptor(
+in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
+make_tuple(PassThrough<N>{},
+PassThrough<C>{},
+Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
+Slice<Sequence<HTilda, WTilda>,
+Sequence<iHTildaLeft, iWTildaLeft>,
+Sequence<iHTildaRight, iWTildaRight>>{}),
+make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
+make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<>{}, Sequence<2, 3>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
-in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
-make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
-make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
+in_n_c_htildaslice_wtildaslice_global_desc,
+make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
+make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto gridwise_gemm =
...
-#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
-#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP
+#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
@@ -8,11 +8,12 @@
namespace ck {
-// Number of GEMM partition: YTilda * XTilda
-// Number of GEMM iteration: YDotSlice * XDotSlice
-// GemmM = C
-// GemmN = N * HTildaSlice * WTildaSlice
-// GemmK = K
+// Number of GEMMs = YTilda * XTilda
+// GemmM = C
+// GemmN = N * HTildaSlice * WTildaSlice
+// GemmK0 = YDotSlice
+// GemmK1 = XDotSlice
+// GemmK2 = K
template <index_t GridSize,
index_t BlockSize,
typename Float,
@@ -42,10 +43,10 @@ template <index_t GridSize,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
-index_t GemmBBlockCopySrcDataPerRead_GemmN,
+index_t GemmBBlockCopySrcDataPerRead_GemmK2,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
-struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
+struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
{
__host__ __device__ static constexpr index_t GetNumberOfGemm()
{
@@ -67,16 +68,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
__host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
{
constexpr index_t N = InGlobalDesc::GetLengths()[0];
-constexpr index_t C = InGlobalDesc::GetLengths()[1];
-constexpr index_t Hi = InGlobalDesc::GetLengths()[2];
-constexpr index_t Wi = InGlobalDesc::GetLengths()[3];
-constexpr index_t K = OutGlobalDesc::GetLengths()[1];
-constexpr index_t Ho = OutGlobalDesc::GetLengths()[2];
-constexpr index_t Wo = OutGlobalDesc::GetLengths()[3];
-constexpr index_t Y = WeiGlobalDesc::GetLengths()[2];
-constexpr index_t X = WeiGlobalDesc::GetLengths()[3];
+constexpr index_t Hi = InGlobalDesc::GetLengths()[1];
+constexpr index_t Wi = InGlobalDesc::GetLengths()[2];
+constexpr index_t C = InGlobalDesc::GetLengths()[3];
+constexpr index_t Ho = OutGlobalDesc::GetLengths()[1];
+constexpr index_t Wo = OutGlobalDesc::GetLengths()[2];
+constexpr index_t K = OutGlobalDesc::GetLengths()[3];
+constexpr index_t Y = WeiGlobalDesc::GetLengths()[1];
+constexpr index_t X = WeiGlobalDesc::GetLengths()[2];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
@@ -120,9 +121,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
-index_t GemmK = K * YDotSlice * XDotSlice;
+index_t GemmK0 = YDotSlice;
+index_t GemmK1 = XDotSlice;
+index_t GemmK2 = K;
-return Array<index_t, 3>{GemmM, GemmN, GemmK};
+return Array<index_t, 5>{GemmM, GemmN, GemmK0, GemmK1, GemmK2};
}
__host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
@@ -146,21 +149,21 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global)
{
-constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
-constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
-constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
-constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
-constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
-constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
-constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
-constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
-constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
-constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
-constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
-constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
+constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{};
+constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{};
+constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{};
+constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[0];
+constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[1];
+constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[2];
+constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[3];
+constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[1];
+constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[2];
+constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[3];
+constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[1];
+constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[2];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
@@ -203,10 +206,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
// weight out-of-bound check can be skipped
constexpr bool wei_skip_out_of_bound_check = true;
-constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
-wei_k_c_y_x_global_desc,
+constexpr auto wei_k_ydot_ytilda_xdot_xtilda_c_global_desc = transform_tensor_descriptor(
+wei_k_y_x_c_global_desc,
make_tuple(PassThrough<K>{},
-PassThrough<C>{},
Embed<Y,
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
@@ -214,31 +216,24 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
Embed<X,
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
-wei_skip_out_of_bound_check>{}),
+wei_skip_out_of_bound_check>{},
+PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
+make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
-constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
-transform_tensor_descriptor(
-wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
-make_tuple(
-PassThrough<K>{},
-PassThrough<C>{},
-Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
-Slice<Sequence<YTilda, XTilda>,
-Sequence<iYTilda, iXTilda>,
-Sequence<iYTilda + 1, iXTilda + 1>>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
-constexpr auto wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc = transform_tensor_descriptor(
-wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
-make_tuple(PassThrough<YDotSlice>{},
-PassThrough<XDotSlice>{},
-PassThrough<K>{},
-Merge<Sequence<C, 1, 1>>{}),
-make_tuple(Sequence<2>{}, Sequence<4>{}, Sequence<0>{}, Sequence<1, 3, 5>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+constexpr auto wei_k_ydotslice_xdotslice_c_global_desc = transform_tensor_descriptor(
+wei_k_ydot_ytilda_xdot_xtilda_c_global_desc,
+make_tuple(
+PassThrough<K>{},
+Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
+Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
+PassThrough<C>{}),
+make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
+make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<>{}, Sequence<3>{}));
+constexpr auto wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc =
+reorder_tensor_descriptor_given_lower2upper(wei_k_ydotslice_xdotslice_c_global_desc,
+Sequence<2, 0, 1, 3>{});
// B matrix: output tensor
// TODO sometimes output tensor out-of-bound check can be skipped, find out all such
@@ -249,10 +244,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
constexpr bool out_skip_out_of_bound_check = true;
#endif
-constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
-out_n_k_ho_wo_global_desc,
+constexpr auto out_n_ydot_htilda_xdot_wtilda_k_global_desc = transform_tensor_descriptor(
+out_n_ho_wo_k_global_desc,
make_tuple(PassThrough<N>{},
-PassThrough<K>{},
Embed<Ho,
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
@@ -260,46 +254,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
Embed<Wo,
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
-out_skip_out_of_bound_check>{}),
+out_skip_out_of_bound_check>{},
+PassThrough<K>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
+make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
-constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
-transform_tensor_descriptor(
-out_n_k_ydot_htilda_xdot_wtilda_global_desc,
-make_tuple(PassThrough<N>{},
-PassThrough<K>{},
-PassThrough<YTilda>{},
-PassThrough<XTilda>{},
-Slice<Sequence<HTilda, WTilda>,
-Sequence<iHTildaLeft, iWTildaLeft>,
-Sequence<iHTildaRight, iWTildaRight>>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
-constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
-transform_tensor_descriptor(
-out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
-make_tuple(
-PassThrough<N>{},
-PassThrough<K>{},
-PassThrough<HTildaSlice>{},
-PassThrough<WTildaSlice>{},
-Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
+constexpr auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc =
+transform_tensor_descriptor(
+out_n_ydot_htilda_xdot_wtilda_k_global_desc,
+make_tuple(
+PassThrough<N>{},
+Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
+Slice<Sequence<HTilda, WTilda>,
+Sequence<iHTildaLeft, iWTildaLeft>,
+Sequence<iHTildaRight, iWTildaRight>>{},
+PassThrough<K>{}),
+make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
+make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}));
constexpr auto out_gemmk0_gemmk1_gemmk2_gemmn_global_desc = transform_tensor_descriptor(
-out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
+out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc,
make_tuple(PassThrough<YDotSlice>{},
PassThrough<XDotSlice>{},
PassThrough<K>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
-make_tuple(Sequence<2>{}, Sequence<4>{}, Sequence<1>{}, Sequence<0, 3, 5>{}),
+make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// C matrix: input tensor
@@ -310,22 +289,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
constexpr bool in_skip_out_of_bound_check = true;
#endif
-constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
-in_n_c_hi_wi_global_desc,
-make_tuple(
-PassThrough<N>{},
-PassThrough<C>{},
-Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
-constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
-constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
-constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
-in_n_c_hip_wip_global_desc,
+constexpr auto in_n_hip_wip_c_global_desc = transform_tensor_descriptor(
+in_n_hi_wi_c_global_desc,
+make_tuple(PassThrough<N>{},
+Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{},
+PassThrough<C>{}),
+make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
+make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
+constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[1];
+constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[2];
+constexpr auto in_n_ytilda_htilda_xtilda_wtilda_c_global_desc = transform_tensor_descriptor(
+in_n_hip_wip_c_global_desc,
make_tuple(PassThrough<N>{},
-PassThrough<C>{},
Embed<Hip,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
@@ -333,44 +310,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
Embed<Wip,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
-in_skip_out_of_bound_check>{}),
+in_skip_out_of_bound_check>{},
+PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
+make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
-constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
-transform_tensor_descriptor(
-in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
-make_tuple(PassThrough<N>{},
-PassThrough<C>{},
-PassThrough<YTilda>{},
-PassThrough<XTilda>{},
-Slice<Sequence<HTilda, WTilda>,
-Sequence<iHTildaLeft, iWTildaLeft>,
-Sequence<iHTildaRight, iWTildaRight>>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
-constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
-transform_tensor_descriptor(
-in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
-make_tuple(PassThrough<N>{},
-PassThrough<C>{},
-PassThrough<HTildaSlice>{},
-PassThrough<WTildaSlice>{},
-Slice<Sequence<YTilda, XTilda>,
-Sequence<iYTilda, iXTilda>,
-Sequence<iYTilda + 1, iXTilda + 1>>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
-make_tuple(
-Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
+constexpr auto in_n_htildaslice_wtildaslice_c_global_desc = transform_tensor_descriptor(
+in_n_ytilda_htilda_xtilda_wtilda_c_global_desc,
+make_tuple(PassThrough<N>{},
+Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
+Slice<Sequence<HTilda, WTilda>,
+Sequence<iHTildaLeft, iWTildaLeft>,
+Sequence<iHTildaRight, iWTildaRight>>{},
+PassThrough<C>{}),
+make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
+make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
-in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
-make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
-make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
+in_n_htildaslice_wtildaslice_c_global_desc,
+make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
+make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// call GEMM
@@ -404,12 +363,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
-Sequence<0, 1, 2, 3>,
-Sequence<0, 1, 2, 3>,
-3,
-GemmBBlockCopySrcDataPerRead_GemmN,
+Sequence<0, 1, 3, 2>,
+Sequence<0, 1, 3, 2>,
+2,
+GemmBBlockCopySrcDataPerRead_GemmK2,
GemmBBlockCopyDstDataPerWrite_GemmN,
-Sequence<0, 1, 2, 3>,
+Sequence<2, 3, 0, 1>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
...
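// Note on wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc above (reading inferred from
// the surrounding code): reorder_tensor_descriptor_given_lower2upper with
// Sequence<2, 0, 1, 3> sends lower dimension i to upper dimension 2, 0, 1, 3
// respectively, so the lower layout [K, YDotSlice, XDotSlice, C] reads back as
// [YDotSlice, XDotSlice, K, C] = [GemmK0, GemmK1, GemmK2, GemmM], matching the
// GemmK0/GemmK1/GemmK2 comment at the top of the file.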
@@ -488,6 +488,49 @@ struct Embed
}
};
+// LowerLengths: Sequence<...>
+// LowerFreezePoint: Sequence<...>
+template <typename LowerLengths, typename LowerFreezePoint>
+struct Freeze
+{
+static constexpr index_t nDimLow = LowerLengths::Size();
+static constexpr index_t nDimUp = 0;
+using LowerIndex = MultiIndex<nDimLow>;
+using UpperIndex = MultiIndex<nDimUp>;
+__host__ __device__ explicit constexpr Freeze()
+{
+// TODO: sanity check: LowerFreezePoint should be within range of LowerLengths
+}
+__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
+__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<0>{}; }
+__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<>{}; }
+__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /*idx_up*/)
+{
+return to_array(LowerFreezePoint{});
+}
+__host__ __device__ static constexpr auto
+CalculateLowerIndexDiff(const UpperIndex& /* idx_up_diff */,
+const UpperIndex& /* idx_up_old */,
+const LowerIndex& /* idx_low_old */)
+{
+return make_zero_array<index_t, nDimLow>();
+}
+__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
+__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+{
+return true;
+}
+};
template <index_t LowerLength, index_t VectorSize>
struct Vectorize
{
...
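// A minimal reading of Freeze, with hypothetical lengths: it is the
// zero-upper-dimension counterpart of Slice. Freezing lower lengths {2, 2}
// (e.g. {YTilda, XTilda}) at the point {1, 0} removes both dimensions from the
// upper tensor entirely: CalculateLowerIndex() always returns {1, 0}, so the
// frozen dimensions contribute only a constant offset, and
// CalculateLowerIndexDiff() is all zeros because the empty upper index never moves.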
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
std::size_t nrepeat)
{
using namespace ck;
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256, each thread holds 64 output elements
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
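// sanity check on the "64 elements" above: the 128 x 128 block tile spread over
// 256 threads is 64 C-elements per thread; equivalently the thread-level repeat
// count is 128 / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) = 2
// per dimension, giving (2 * 4) * (2 * 4) = 64 accumulators per thread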
#elif 1
// BlockSize = 256, each thread holds 64 output elements
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, each thread holds 64 output elements
// for 1x1 weight, 8x8 input
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#endif
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
constexpr index_t GemmM = C * YTilda * XTilda;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
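// e.g. with hypothetical sizes GemmM = 1024 and GemmN = 16384:
// GridSize = ceil(1024 / 128) * ceil(16384 / 128) = 8 * 128 = 1024 workgroups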
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
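// ave_time is in ms: flops / 1e9 gives GFlop, and GFlop per ms equals TFlop/s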
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
std::size_t nrepeat)
{
using namespace ck;
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256, each thread holds 64 data values
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
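// Note: unlike the v2r1 launcher above, GemmM here is just C rather than
// C * YTilda * XTilda; presumably the v3r1 kernel folds the YTilda/XTilda
// decomposition into its own index calculation, so the grid only tiles
// C x (N * HTildaSlice * WTildaSlice).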
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
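For reference, a minimal sketch of how this launcher might be driven from host code (hypothetical sizes and tensor names; the descriptor and host-tensor helpers are the ones used elsewhere in this commit). With a 3x3 filter, 2x2 stride, no padding and a 35x35 input, the output is (35 - 3) / 2 + 1 = 17x17:

// hypothetical driver snippet, not part of this commit
constexpr auto in_nchw_desc  = make_native_tensor_descriptor_packed(Sequence<128, 256, 35, 35>{});
constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<1280, 256, 3, 3>{});
constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<128, 1280, 17, 17>{});
Tensor<float> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
Tensor<float> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
Tensor<float> out_nkhw(make_HostTensorDescriptor(out_nkhw_desc));
// ... fill wei_kcyx and out_nkhw, then:
launcher::device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(
    in_nchw_desc, in_nchw,
    wei_kcyx_desc, wei_kcyx,
    out_nkhw_desc, out_nkhw,
    Sequence<2, 2>{}, // ConvStrides
    Sequence<1, 1>{}, // ConvDilations
    Sequence<0, 0>{}, // InLeftPads
    Sequence<0, 0>{}, // InRightPads
    100 /* nrepeat */);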
@@ -57,8 +57,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
 wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
 out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
-#if 1
-// BlockSize = 256, each thread hold 64 data
+#if 0
+// cdata = 64, BlockSize = 256, 128x128x8
 constexpr index_t BlockSize = 256;
 constexpr index_t GemmMPerBlock = 128;
@@ -86,6 +86,36 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
 constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
 constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 1
+// cdata = 64, BlockSize = 256, 128x128x16
+constexpr index_t BlockSize = 256;
+constexpr index_t GemmMPerBlock = 128;
+constexpr index_t GemmNPerBlock = 128;
+constexpr index_t GemmKPerBlock = 16;
+constexpr index_t GemmMPerThread = 4;
+constexpr index_t GemmNPerThread = 4;
+constexpr index_t GemmKPerThread = 1;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 4;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmThreadGemmDataPerReadM = 4;
+constexpr index_t GemmThreadGemmDataPerReadN = 4;
+using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>;
+using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
+constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
+using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
+constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
+constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif
...
@@ -3,7 +3,7 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "gridwise_operation_wrapper.hpp"
-#include "gridwise_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
+#include "gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp"
 namespace launcher {
@@ -17,7 +17,7 @@ template <typename T,
 typename ConvDilations,
 typename InLeftPads,
 typename InRightPads>
-void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
+void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc in_nchw_desc,
 Tensor<T>& in_nchw,
 WeiDesc wei_kcyx_desc,
 const Tensor<T>& wei_kcyx,
@@ -48,17 +48,41 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
 constexpr index_t ConvDilationH = ConvDilations{}[0];
 constexpr index_t ConvDilationW = ConvDilations{}[1];
+constexpr auto in_nhwc_desc = make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
+constexpr auto wei_kyxc_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
+constexpr auto out_nhwk_desc = make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});
+Tensor<float> in_nhwc(make_HostTensorDescriptor(in_nhwc_desc));
+Tensor<float> wei_kyxc(make_HostTensorDescriptor(wei_kyxc_desc));
+Tensor<float> out_nhwk(make_HostTensorDescriptor(out_nhwk_desc));
+auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
+    in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
+};
+auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
+    wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
+};
+auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
+    out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
+};
+make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
+make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
+make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
 std::size_t data_sz = sizeof(T);
-DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
-DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
-DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
-in_nchw_device_buf.ToDevice(in_nchw.mData.data());
-wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
-out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
+DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
+DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
+DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
+in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
+wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
+out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
-#if 1
-// BlockSize = 256, each thread hold 64 data
+#if 0
+// cdata = 64, BlockSize = 256, 128x128x8
 constexpr index_t BlockSize = 256;
 constexpr index_t GemmMPerBlock = 128;
@@ -74,16 +98,46 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
 constexpr index_t GemmThreadGemmDataPerReadM = 4;
 constexpr index_t GemmThreadGemmDataPerReadN = 4;
-using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 4, 1>;
-using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 2, 128>;
-constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
-constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 1, 4>;
+using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>;
+constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
+constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
 using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 4, 1>;
 using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;
-constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
+constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4;
+constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 1
+// cdata = 64, BlockSize = 256, 128x128x16
+constexpr index_t BlockSize = 256;
+constexpr index_t GemmMPerBlock = 128;
+constexpr index_t GemmNPerBlock = 128;
+constexpr index_t GemmKPerBlock = 16;
+constexpr index_t GemmMPerThread = 4;
+constexpr index_t GemmNPerThread = 4;
+constexpr index_t GemmKPerThread = 1;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 4;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmThreadGemmDataPerReadM = 4;
+constexpr index_t GemmThreadGemmDataPerReadN = 4;
+using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 2, 4>;
+using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>;
+constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
+constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
+using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 8, 1>;
+using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;
+constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4;
 constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
@@ -132,14 +186,14 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
 for(index_t i = 0; i < nrepeat; ++i)
 {
 using GridwiseConvBwdData =
-    GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw<
+    GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk<
 GridSize,
 BlockSize,
 T,
 T,
-decltype(in_nchw_desc),
-decltype(wei_kcyx_desc),
-decltype(out_nkhw_desc),
+decltype(in_nhwc_desc),
+decltype(wei_kyxc_desc),
+decltype(out_nhwk_desc),
 ConvStrides,
 ConvDilations,
 InLeftPads,
@@ -162,14 +216,14 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
 GemmABlockCopyDstDataPerWrite_GemmM,
 GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
 GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
-GemmBBlockCopySrcDataPerRead_GemmN,
+GemmBBlockCopySrcDataPerRead_GemmK2,
 GemmBBlockCopyDstDataPerWrite_GemmN,
 GemmCThreadCopyDstDataPerWrite_GemmN1>;
 static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
 constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
-constexpr index_t gemm_k = gemm_sizes.At(2);
-constexpr bool is_gemm_not_empty = gemm_k > 0;
+constexpr index_t gemm_k2 = gemm_sizes.At(4);
+constexpr bool is_gemm_not_empty = gemm_k2 > 0;
 // only compile and run if GEMM is not empty
 static_if<is_gemm_not_empty>{}([&](auto fwd) {
@@ -182,9 +236,9 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
 dim3(BlockSize),
 0,
 0,
-static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()),
+static_cast<T*>(in_nhwc_device_buf.GetDeviceBuffer()),
+static_cast<T*>(wei_kyxc_device_buf.GetDeviceBuffer()),
+static_cast<T*>(out_nhwk_device_buf.GetDeviceBuffer()),
 fwd(gemm_id));
 });
 });
@@ -200,7 +254,13 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
 std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
 }
-in_nchw_device_buf.FromDevice(in_nchw.mData.data());
+in_nhwc_device_buf.FromDevice(in_nhwc.mData.data());
+auto f_nhwc2nchw = [&](auto n, auto c, auto hi, auto wi) {
+    in_nchw(n, c, hi, wi) = in_nhwc(n, hi, wi, c);
+};
+make_ParallelTensorFunctor(f_nhwc2nchw, N, C, Hi, Wi)(std::thread::hardware_concurrency());
 }
 } // namespace launcher
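The v5r1 path above reaches NHWC by transposing each tensor on the host before the device upload. A minimal standalone sketch of that index mapping (hypothetical helper name, plain nested loops instead of make_ParallelTensorFunctor; not part of the commit):

// Transpose a packed NCHW buffer into a packed NHWC buffer.
// NCHW offset: ((n * C + c) * H + h) * W + w
// NHWC offset: ((n * H + h) * W + w) * C + c
#include <cstddef>
#include <vector>
void nchw_to_nhwc(const std::vector<float>& src, std::vector<float>& dst,
                  std::size_t N, std::size_t C, std::size_t H, std::size_t W)
{
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                    dst[((n * H + h) * W + w) * C + c] = src[((n * C + c) * H + h) * W + w];
}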
@@ -15,10 +15,8 @@
 #include "host_conv_bwd_data.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp"
 int main(int argc, char* argv[])
 {
@@ -56,7 +54,7 @@ int main(int argc, char* argv[])
 #elif 0
 // 3x3, 28x28
 constexpr index_t N = 128;
-constexpr index_t C = 1024;
+constexpr index_t C = 256;
 constexpr index_t HI = 28;
 constexpr index_t WI = 28;
 constexpr index_t K = 1024;
@@ -161,7 +159,7 @@ int main(int argc, char* argv[])
 #elif 0
 // 1x7 filter, 0x3 pad, 17x17 input
 constexpr index_t N = 128;
-constexpr index_t C = 1024;
+constexpr index_t C = 256;
 constexpr index_t HI = 17;
 constexpr index_t WI = 17;
 constexpr index_t K = 1024;
@@ -173,10 +171,10 @@ int main(int argc, char* argv[])
 using LeftPads = Sequence<0, 3>;
 using RightPads = Sequence<0, 3>;
-#elif 1
+#elif 0
 // 7x1 filter, 3x0 pad, 17x17 input
 constexpr index_t N = 128;
-constexpr index_t C = 1024;
+constexpr index_t C = 256;
 constexpr index_t HI = 17;
 constexpr index_t WI = 17;
 constexpr index_t K = 1024;
@@ -188,13 +186,13 @@ int main(int argc, char* argv[])
 using LeftPads = Sequence<3, 0>;
 using RightPads = Sequence<3, 0>;
-#elif 0
+#elif 1
 // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
 constexpr index_t N = 128;
-constexpr index_t C = 128;
+constexpr index_t C = 256;
 constexpr index_t HI = 35;
 constexpr index_t WI = 35;
-constexpr index_t K = 128;
+constexpr index_t K = 1280;
 constexpr index_t Y = 3;
 constexpr index_t X = 3;
@@ -251,13 +249,9 @@ int main(int argc, char* argv[])
 #elif 0
 device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
 #elif 0
-device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
-#elif 0
-device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
-#elif 1
 device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
 #elif 1
-device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw
+device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
 #endif
 (in_nchw_desc,
 in_nchw_device,
...