"vscode:/vscode.git/clone" did not exist on "86185bd7ce1b84696f064822e05837dd63e4f218"
Unverified Commit c5da0377 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Added bwd data v3r1 v4r1, tweaking v1 (#10)

* Added bwd data v3r1: breaking down compute into a series of load balanced GEMM, and launch in a single kernel
* Added bwd data v4r1: like v3r1, but launch GEMMs in multiple kernels
* Tweaked v1r1  and v1r2 (atomic) on AMD GPU
parent e2b4c5b4
......@@ -111,8 +111,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hi + InLeftPads::At(0) + InRightPads::At(0),
Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + InLeftPads::At(1) + InRightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -122,7 +126,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM: atomic add
// GEMM
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
......@@ -131,7 +139,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc),
InMemoryDataOperation::atomic_add,
in_memory_op,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
......
......@@ -352,8 +352,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
}
}
// input: register to global memory, atomic add
{
#if 1 // debug
// input: register to global memory, atomic add
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
#else
constexpr auto in_memory_op = InMemoryDataOperation::atomic_add;
#endif
constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t E0 = E / E1;
......@@ -378,8 +386,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1>>{},
UnMerge<Sequence<C0, C1>>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hi + LeftPads::At(0) + RightPads::At(0),
Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + LeftPads::At(1) + RightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
......@@ -422,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
InThreadCopyDstDataPerWrite_B,
AddressSpace::vgpr,
AddressSpace::global,
InMemoryDataOperation::atomic_add>({0, 0, 0, 0, 0, 0},
{e_thread_data_on_global / E1,
e_thread_data_on_global % E1,
0,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1,
0})
in_memory_op>({0, 0, 0, 0, 0, 0},
{e_thread_data_on_global / E1,
e_thread_data_on_global % E1,
0,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1,
0})
.Run(p_in_thread, p_in_global);
}
}
......
......@@ -9,7 +9,7 @@
namespace ck {
// GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * Ydot * Xdot;
template <index_t GridSize,
index_t BlockSize,
......@@ -20,8 +20,8 @@ template <index_t GridSize,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InputLeftPads,
typename InputRightPads,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
......@@ -71,6 +71,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
......@@ -78,6 +79,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
......@@ -88,30 +90,34 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
// weight tensor
constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Pad<Sequence<Y, X>,
Sequence<0, 0>,
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_yp_xp_global_desc,
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Sequence<Ydot, Ytilda>,
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Xtilda>,
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -123,55 +129,96 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_hop_wop_global_desc,
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>,
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Wtilda>,
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InputLeftPads, InputRightPads>{}),
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm =
......
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// Ytilda*Xtilda number of GEMMs
// GemmM = C;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * YdotNonZero * XdotNonZero;
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
{
// this is a hack, should query this info from gridwise_gemm instead of duplicate its logic
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned(
Sequence<GemmKPerBlock, GemmMPerBlock>{}, Number<max_lds_align>{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned(
Sequence<GemmKPerBlock, GemmNPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
}
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr bool wei_skip_all_out_of_bound_check = true;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#if 1 // debug
constexpr bool out_skip_all_out_of_bound_check = false;
#else
constexpr bool out_skip_all_out_of_bound_check = true;
#endif
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
// GEMMs
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
#if 1 // debug
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
#else
static_for<0, 1, 1>{}([&](auto ytilda_) {
static_for<0, 1, 1>{}([&](auto xtilda_) {
#endif
constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){};
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
// A matrix
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block);
// is synchronization necessary?
__syncthreads();
});
});
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = C
// GemmN = N * Htilda * Wtilda;
// GemmK = K * YdotNonZero * XdotNonZero
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t Iter_ytilda,
index_t Iter_xtilda,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr bool wei_skip_all_out_of_bound_check = true;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#if 1 // debug
constexpr bool out_skip_all_out_of_bound_check = false;
#else
constexpr bool out_skip_all_out_of_bound_check = true;
#endif
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
// GEMM
constexpr index_t ytilda = Iter_ytilda;
constexpr index_t xtilda = Iter_xtilda;
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
// A matrix
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{}, Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
};
} // namespace ck
#endif
......@@ -181,12 +181,15 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
......
......@@ -97,12 +97,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......
......@@ -41,15 +41,17 @@ struct PassThrough
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
};
// LowerLengths: Sequence<...>
template <typename LowerLengths, typename LeftPads, typename RightPads>
template <typename LowerLengths,
typename LeftPads,
typename RightPads,
bool SkipIsValidCheck = false>
struct Pad
{
static constexpr index_t nDim = LowerLengths::Size();
......@@ -57,6 +59,13 @@ struct Pad
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;
__host__ __device__ explicit constexpr Pad()
{
static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim &&
RightPads::GetSize() == nDim,
"wrong! # of dimensions not consistent");
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
......@@ -81,17 +90,80 @@ struct Pad
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
#if 1 // debug
if(SkipIsValidCheck)
{
return true;
}
#endif
bool flag = true;
for(index_t i = 0; i < nDim; ++i)
{
flag = flag && LeftPads::At(i) == 0 && RightPads::At(i) == 0;
}
return flag;
}
};
// LowerLengths: Sequence<...>
// SliceBegins: Sequence<...>
// SliceEnds: Sequence<...>
template <typename LowerLengths, typename SliceBegins, typename SliceEnds>
struct Slice
{
static constexpr index_t nDim = LowerLengths::Size();
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;
__host__ __device__ explicit constexpr Slice()
{
static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim &&
SliceEnds::GetSize() == nDim,
"wrong! # of dimensions not consistent");
#if 0
// TODO: would not compile, error on constexpr
static_for<0, nDim, 1>{}([&](auto idim) {
flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
(idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
SliceBegins::At(idim) >= 0 &&
SliceEnds::At(idim) <= LowerLengths::At(idim),
"wrong! Slice config is wrong");
});
#endif
}
return flag;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return SliceEnds{} - SliceBegins{};
}
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up + SliceBegins{};
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
};
......@@ -165,85 +237,101 @@ struct Merge
const UpperIndex& /* idx_up_old */,
const LowerIndex& idx_low_old)
{
// do nothing if idx_up_diff == 0
if(idx_up_diff[0] == 0)
{
return make_zero_array<index_t, nDimLow>();
}
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
// If idx_up_diff is known at compile-time, the calculation can
// be done at compile-time. However, if idx_up_diff is only known
// at run-time, then the calculation will also be computed at
// run-time, and can be very expensive.
LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);
if(idx_up_diff[0] > 0)
else
{
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(carry)
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
// If idx_up_diff is known at compile-time, the calculation can
// be done at compile-time. However, if idx_up_diff is only known
// at run-time, then the calculation will also be computed at
// run-time, and can be very expensive.
LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff);
// find out the last low dimension that changed
index_t last_changed_low_dim = 0;
static_for<0, nDimLow, 1>{}([&](auto i) {
if(idx_low_diff_tmp[i] != 0)
{
++idx_low_new(i);
last_changed_low_dim = i;
}
});
carry = false;
LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp;
if(idx_low_new[i] >= LowerLengths::At(i))
if(idx_up_diff[0] > 0)
{
// do carry check on each low dimension in reversed order
// starting from the first digit that changed
// don't check the highest dimension
bool carry = false;
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
if(i <= last_changed_low_dim)
{
if(carry)
{
++idx_low_new(i);
}
carry = false;
if(idx_low_new[i] >= LowerLengths::At(i))
{
idx_low_new(i) -= LowerLengths::At(i);
carry = true;
}
}
});
// highest dimension, no out-of-bound check
if(carry)
{
idx_low_new(i) -= LowerLengths::At(i);
carry = true;
++idx_low_new(0);
}
});
// highest dimension, no out-of-bound check
if(carry)
{
++idx_low_new(0);
}
}
else if(idx_up_diff[0] < 0)
{
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
else
{
// do borrow check on each low dimension in reversed order
// starting from the first digit that changed
// don't check the highest dimension
bool borrow = false;
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
if(i <= last_changed_low_dim)
{
if(borrow)
{
--idx_low_new(i);
}
borrow = false;
if(idx_low_new[i] < 0)
{
idx_low_new(i) += LowerLengths::At(i);
borrow = true;
}
}
});
// highest dimension, no out-of-bound check
if(borrow)
{
--idx_low_new(i);
}
borrow = false;
if(idx_low_new[i] < 0)
{
idx_low_new(i) += LowerLengths::At(i);
borrow = true;
--idx_low_new(0);
}
});
// highest dimension, no out-of-bound check
if(borrow)
{
--idx_low_new(0);
}
}
return idx_low_new - idx_low_old;
return idx_low_new - idx_low_old;
}
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
......@@ -290,8 +378,7 @@ struct UnMerge
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
......@@ -300,7 +387,10 @@ struct UnMerge
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <typename UpperLengths, typename Coefficients>
template <index_t LowerLength,
typename UpperLengths,
typename Coefficients,
bool SkipIsValidCheck = false>
struct Embed
{
static constexpr index_t nDimLow = 1;
......@@ -325,8 +415,10 @@ struct Embed
{
LowerIndex idx_low(Coefficients{}[nDimUp]);
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
for(index_t i = 0; i < nDimUp; ++i)
{
idx_low(0) += idx_up[i] * Coefficients{}[i];
}
return idx_low;
}
......@@ -338,18 +430,55 @@ struct Embed
{
LowerIndex idx_low_diff{0};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
for(index_t i = 0; i < nDimUp; ++i)
{
idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i];
}
return idx_low_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
#if 1 // debug
if(SkipIsValidCheck)
{
return true;
}
#endif
bool flag = true;
index_t ncorner = 1;
for(index_t idim = 0; idim < nDimUp; ++idim)
{
ncorner *= 2;
}
// loop over each corner of the upper tensor
for(index_t icorner = 0; icorner < ncorner; ++icorner)
{
// generate upper index for each corner
auto idx_up = make_zero_array<index_t, nDimUp>();
index_t itmp = icorner;
for(index_t idim = nDimUp - 1; idim >= 0; --idim)
{
idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1;
itmp /= 2;
}
// calculate lower index
auto idx_low = CalculateLowerIndex(idx_up);
// judge if lower index is valid
flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength;
}
return flag;
}
};
......@@ -389,8 +518,7 @@ struct Vectorize
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
......
......@@ -53,6 +53,8 @@ struct NativeTensorCoordinate
__host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
__host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; }
__host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }
__host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }
......@@ -98,7 +100,24 @@ struct NativeTensorCoordinate
return tensor_desc_type::CalculateOffsetDiff(idx_diff);
}
__host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}
// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
// For native tensor, offset is valid if upper-index is valid
return IsUpperIndexValid();
}
// evaluated at compile-time
__host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid()
{
return true;
}
private:
// mIndex may be saved and updated, however, the value of some (or all) of its entries may
......@@ -206,10 +225,30 @@ struct TransformedTensorCoordinate
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
}
__host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}
// evaluted at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid();
}
// most evaluatation is done at comile-time
__host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const
{
return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid(
GetLowerCoordinate().GetIndex());
}
// most evaluatation is done at comile-time
__host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const
{
return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
mCoordLow.IsUpperIndexMappedToValidOffset();
return IsLowerIndexValidAssumingUpperIndexIsValid() &&
GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid();
}
private:
......
......@@ -120,10 +120,17 @@ struct NativeTensorDescriptor
return Tuple<>{};
}
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const Index& /* idx */)
// a multi-index is valid if there is a corresponding point for it in the tensor
__host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx)
{
return true;
bool flag = true;
for(index_t i = 0; i < nDim; ++i)
{
flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i];
}
return flag;
}
};
......@@ -467,33 +474,49 @@ struct TransformedTensorDescriptor
}
#endif
// a multi-index is valid if there is a corresponding point for it in the tensor
__host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const
{
bool flag = true;
for(index_t i = 0; i < nDimUp; ++i)
{
flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i];
}
return flag;
}
// this function is for optimization purpose, it's called by tensor coordinate
// this function tells you: If a lower-index is valid or not, assuming upper index is valid
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low)
{
bool flag = true;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part));
// check a indtransformation if it does not always has a valid mapping
constexpr bool is_valid_up_always_mapped_to_valid_low =
decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();
if(!is_valid_up_always_mapped_to_valid_low)
{
constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
constexpr auto low_lengths_part =
GetLowerTensorDescriptor().GetLengths(low_dims_part);
const auto idx_low_part = to_array(pick_array_element(idx_low, low_dims_part));
for(index_t i = 0; i < low_dims_part.Size(); ++i)
{
flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
}
}
});
return flag;
}
// Whenever this function is called, it will call CalculateLowerIndex() recursively.
// If you have created a tensor coordinate already, instead of calling this function,
// you should call TensorCoordinate::IsUpperIndexMappedToValidOffset() which would
// be less expensive.
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
{
return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
CalculateLowerIndex(idx_up));
}
};
} // namespace ck
......
......@@ -50,9 +50,37 @@ template <index_t GridSize,
index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
constexpr auto True = integral_constant<bool, true>{};
......@@ -64,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// don't do anything if K == 0
if(K == 0)
{
return;
}
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
......@@ -184,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
__shared__ Float p_a_block_double[2 * a_block_space];
__shared__ Float p_b_block_double[2 * b_block_space];
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
......@@ -329,6 +363,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
.Run(p_c_thread, p_c_global);
}
}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(p_a_global, p_b_global, p_c_global, p_shared_block);
}
};
} // namespace ck
......
......@@ -110,7 +110,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Check src data's valid mapping situation, only check the first data in this src
// vector. It's user's responsiblity to make sure all data in the src vector
// has the valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
......@@ -142,7 +142,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Check dst data's valid mapping situation, only check the first data in this dst
// vector. It's user's responsiblity to make sure all data in the dst vector
// has the valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,
......@@ -260,7 +260,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// src
// vector. It's user's responsiblity to make sure all data in the src vector
// has the valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
......@@ -299,7 +299,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// dst
// vector. It's user's responsiblity to make sure all data in the dst vector
// has the valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,
......@@ -399,7 +399,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// src
// vector. It's user's responsiblity to make sure all data in the src vector
// has the valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
......@@ -444,7 +444,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// dst
// vector. It's user's responsiblity to make sure all data in the dst vector
// has the valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,
......
......@@ -54,6 +54,13 @@ __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
__device__ void
__llvm_amdgcn_buffer_atomic_add(float vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
// buffer_load requires:
// 1) p_src must be in global memory space, d_dst must be vgpr
// 2) p_src to be a block-invariant pointer.
......@@ -73,6 +80,13 @@ amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType
index_t dst_thread_data_offset,
index_t dst_const_data_offset);
template <typename T, index_t VectorSize>
__device__ void
amd_intrinsic_buffer_atomic_add(const typename vector_type<T, VectorSize>::MemoryType& src,
T* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset);
template <>
__device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block,
index_t src_thread_data_offset,
......@@ -289,5 +303,31 @@ __device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src,
#endif
}
template <>
__device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src,
float* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
BufferLoadStoreDwordConfig<float> dst_block_config;
// fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block;
// fill in byte 2
dst_block_config.range[2] = -1;
// fill in byte 3
dst_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
__llvm_amdgcn_buffer_atomic_add(
src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
#else
static_assert(false, " wrong! not implemented");
#endif
}
} // namespace ck
#endif
......@@ -29,6 +29,11 @@
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1
#endif
// only support gfx908
#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD
#define CK_USE_AMD_BUFFER_ATOMIC_ADD 0
#endif
// AMD XDLOPS
#ifndef CK_USE_AMD_XDLOPS
#define CK_USE_AMD_XDLOPS 0
......
......@@ -29,9 +29,10 @@ struct static_for
{
__host__ __device__ constexpr static_for()
{
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
static_assert((NEnd - NBegin) % Increment == 0,
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
"wrongs! should have NBegin <= NEnd");
}
template <class F>
......
......@@ -52,8 +52,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
[&](auto) {
#if CK_USE_AMD_BUFFER_ATOMIC_ADD
amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
*reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
#else
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
#endif
})
.Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
......
......@@ -49,6 +49,12 @@ struct integer_divide_ceiler
}
};
template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
return x / y;
}
template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{
......
......@@ -24,9 +24,9 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu)
endif()
add_executable(conv ${CONV_SOURCE})
add_executable(col2im ${COL2IM_SOURCE})
add_executable(conv_bwd_data ${CONV_BWD_DATA_SOURCE})
target_link_libraries(conv PRIVATE host)
target_link_libraries(col2im PRIVATE host)
target_link_libraries(conv_bwd_data PRIVATE host)
add_executable(conv_driver ${CONV_SOURCE})
add_executable(col2im_driver ${COL2IM_SOURCE})
add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
target_link_libraries(conv_driver PRIVATE host)
target_link_libraries(col2im_driver PRIVATE host)
target_link_libraries(conv_bwd_data_driver PRIVATE host)
......@@ -30,33 +30,81 @@ struct KernelTimer
std::unique_ptr<KernelTimerImpl> impl;
};
#if CK_DEVICE_BACKEND_AMD
using device_stream_t = hipStream_t;
template <typename... Args, typename F>
float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
void launch_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
{
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}
template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
{
KernelTimer timer;
#if CK_DEVICE_BACKEND_AMD
timer.Start();
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, 0, args...);
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
timer.End();
hipGetErrorString(hipGetLastError());
return timer.GetElapsedTime();
}
#elif CK_DEVICE_BACKEND_NVIDIA
using device_stream_t = cudaStream_t;
template <typename... Args, typename F>
void launch_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
cudaStream_t stream_id,
Args... args)
{
const void* f = reinterpret_cast<const void*>(kernel);
void* p_args[] = {&args...};
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
}
template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
cudaStream_t stream_id,
Args... args)
{
KernelTimer timer;
const void* f = reinterpret_cast<const void*>(kernel);
void* p_args[] = {&args...};
timer.Start();
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, 0);
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
timer.End();
checkCudaErrors(error);
#endif
return timer.GetElapsedTime();
}
#endif
#endif
......@@ -88,7 +88,8 @@ void device_col2im_eb_nchw(ColDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_col2im),
float time =
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_col2im),
const T* const __restrict__,
T* const __restrict__>,
dim3(GridSize),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment