Unverified Commit c5da0377 authored by Chao Liu, committed by GitHub

Added bwd data v3r1 v4r1, tweaking v1 (#10)

* Added bwd data v3r1: breaks the backward-data computation into a series of load-balanced GEMMs and launches them from a single kernel (see the parameter sketch below)
* Added bwd data v4r1: like v3r1, but launches the GEMMs as multiple kernels
* Tweaked v1r1 and v1r2 (atomic add) on AMD GPU
parent e2b4c5b4
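Editorial note: the new v3r1/v4r1 kernels tile the backward-data convolution over a Ytilda x Xtilda grid of GEMMs. The host-side sketch below is illustrative only and is not part of the commit; std::gcd stands in for math::hcf, and the sample sizes are assumptions. It mirrors how the kernels in this diff derive their tiling parameters.

#include <cstdio>
#include <numeric> // std::gcd (C++17)

int main()
{
    // assumed example: 3x3 filter, stride 2, dilation 1, 4x4 output tile
    const int Y = 3, X = 3, Ho = 4, Wo = 4;
    const int StrideH = 2, StrideW = 2, DilationH = 1, DilationW = 1;

    const int Ytilda = StrideH / std::gcd(StrideH, DilationH);
    const int Xtilda = StrideW / std::gcd(StrideW, DilationW);
    const int Ydot   = (Y + Ytilda - 1) / Ytilda; // integer_divide_ceil
    const int Xdot   = (X + Xtilda - 1) / Xtilda;
    const int Htilda = Ho + (DilationH * (Y - 1) + StrideH - 1) / StrideH;
    const int Wtilda = Wo + (DilationW * (X - 1) + StrideW - 1) / StrideW;

    // v3r1 loops over all Ytilda * Xtilda GEMMs inside one kernel;
    // v4r1 launches one kernel per (ytilda, xtilda) pair.
    std::printf("GEMMs = %d, Ydot = %d, Xdot = %d, Htilda = %d, Wtilda = %d\n",
                Ytilda * Xtilda, Ydot, Xdot, Htilda, Wtilda);
    return 0;
}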
@@ -111,8 +111,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
 in_n_c_hip_wip_global_desc,
 make_tuple(PassThrough<N>{},
 PassThrough<C>{},
-Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
-Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+Embed<Hi + InLeftPads::At(0) + InRightPads::At(0),
+Sequence<Y, Ho>,
+Sequence<ConvDilationH, ConvStrideH, 0>>{},
+Embed<Wi + InLeftPads::At(1) + InRightPads::At(1),
+Sequence<X, Wo>,
+Sequence<ConvDilationW, ConvStrideW, 0>>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -122,7 +126,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
 make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
-// GEMM: atomic add
+// GEMM
+constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
+? InMemoryDataOperation::none
+: InMemoryDataOperation::atomic_add;
 constexpr auto gridwise_gemm =
 GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
 BlockSize,
@@ -131,7 +139,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
 decltype(wei_k_e_global_desc),
 decltype(out_k_b_global_desc),
 decltype(in_e_b_global_desc),
-InMemoryDataOperation::atomic_add,
+in_memory_op,
 GemmMPerBlock,
 GemmNPerBlock,
 GemmKPerBlock,
...
@@ -352,8 +352,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
 }
 }
-// input: register to global memory, atomic add
 {
+#if 1 // debug
+// input: register to global memory, atomic add
+constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
+? InMemoryDataOperation::none
+: InMemoryDataOperation::atomic_add;
+#else
+constexpr auto in_memory_op = InMemoryDataOperation::atomic_add;
+#endif
 constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
 constexpr index_t E0 = E / E1;
@@ -378,8 +386,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
 in_n_c_hip_wip_global_desc,
 make_tuple(UnMerge<Sequence<N0, N1>>{},
 UnMerge<Sequence<C0, C1>>{},
-Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
-Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+Embed<Hi + LeftPads::At(0) + RightPads::At(0),
+Sequence<Y, Ho>,
+Sequence<ConvDilationH, ConvStrideH, 0>>{},
+Embed<Wi + LeftPads::At(1) + RightPads::At(1),
+Sequence<X, Wo>,
+Sequence<ConvDilationW, ConvStrideW, 0>>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
@@ -422,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
 InThreadCopyDstDataPerWrite_B,
 AddressSpace::vgpr,
 AddressSpace::global,
-InMemoryDataOperation::atomic_add>({0, 0, 0, 0, 0, 0},
+in_memory_op>({0, 0, 0, 0, 0, 0},
 {e_thread_data_on_global / E1,
 e_thread_data_on_global % E1,
 0,
 b_thread_data_on_global / B1,
 b_thread_data_on_global % B1,
 0})
 .Run(p_in_thread, p_in_global);
 }
 }
...
@@ -9,7 +9,7 @@
 namespace ck {
 // GemmM = C * Ytilda * Xtilda;
-// GemmN = N * Htilda * Wtilda;
+// GemmN = N * HtildaNonZero * WtildaNonZero;
 // GemmK = K * Ydot * Xdot;
 template <index_t GridSize,
 index_t BlockSize,
@@ -20,8 +20,8 @@ template <index_t GridSize,
 typename OutGlobalDesc,
 typename ConvStrides,
 typename ConvDilations,
-typename InputLeftPads,
-typename InputRightPads,
+typename InLeftPads,
+typename InRightPads,
 index_t GemmMPerBlock,
 index_t GemmNPerBlock,
 index_t GemmKPerBlock,
@@ -71,6 +71,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
 constexpr index_t ConvDilationH = ConvDilations{}[0];
 constexpr index_t ConvDilationW = ConvDilations{}[1];
+#if 0 // debug
 // sanity-check for vectorized memory load
 // TODO: this logic may not be correct for bwd-data
 static_assert(
@@ -78,6 +79,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
 (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
 "wrong! aligment requirement for vectorized global load of input tensor will "
 "be violated");
+#endif
 constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
 constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
@@ -88,30 +90,34 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
 constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
 constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
-constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
-constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
-constexpr index_t Htilda = Ho + right_pad_ho;
-constexpr index_t Wtilda = Wo + right_pad_wo;
-// weight tensor
-constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
-wei_k_c_y_x_global_desc,
-make_tuple(PassThrough<K>{},
-PassThrough<C>{},
-Pad<Sequence<Y, X>,
-Sequence<0, 0>,
-Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
+constexpr index_t Htilda =
+Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
+constexpr index_t Wtilda =
+Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
+constexpr index_t HtildaLeft = math::integer_divide_floor(
+math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
+constexpr index_t WtildaLeft = math::integer_divide_floor(
+math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
+constexpr index_t HtildaRight = math::min(
+Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
+constexpr index_t WtildaRight = math::min(
+Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
+constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
+constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
+// weight tensor
 constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
-wei_k_c_yp_xp_global_desc,
+wei_k_c_y_x_global_desc,
 make_tuple(PassThrough<K>{},
 PassThrough<C>{},
-Embed<Sequence<Ydot, Ytilda>,
+Embed<Y,
+Sequence<Ydot, Ytilda>,
 Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
-Embed<Sequence<Xdot, Xtilda>,
+Embed<X,
+Sequence<Xdot, Xtilda>,
 Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -123,55 +129,96 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
 make_tuple(Sequence<0>{}, Sequence<1>{}));
 // output tensor
-constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
-out_n_k_ho_wo_global_desc,
-make_tuple(
-PassThrough<N>{},
-PassThrough<K>{},
-Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
 constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
-out_n_k_hop_wop_global_desc,
+out_n_k_ho_wo_global_desc,
 make_tuple(PassThrough<N>{},
 PassThrough<K>{},
-Embed<Sequence<Ydot, Htilda>,
+Embed<Ho,
+Sequence<Ydot, Htilda>,
 Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
-Embed<Sequence<Xdot, Wtilda>,
+Embed<Wo,
+Sequence<Xdot, Wtilda>,
 Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
-out_n_k_ydot_htilda_xdot_wtilda_global_desc,
-make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
-make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}));
+constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
+transform_tensor_descriptor(
+out_n_k_ydot_htilda_xdot_wtilda_global_desc,
+make_tuple(PassThrough<N>{},
+PassThrough<K>{},
+PassThrough<Ytilda>{},
+PassThrough<Xtilda>{},
+Slice<Sequence<Htilda, Wtilda>,
+Sequence<HtildaLeft, WtildaLeft>,
+Sequence<HtildaRight, WtildaRight>>{}),
+make_tuple(
+Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
+make_tuple(
+Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
+constexpr auto out_gemmk_gemmn_global_desc =
+transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
+make_tuple(Merge<Sequence<K, Ydot, Xdot>>{},
+Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
+make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
+make_tuple(Sequence<0>{}, Sequence<1>{}));
+#if 1 // debug
+constexpr bool in_skip_all_out_of_bound_check = false;
+#else
+constexpr bool in_skip_all_out_of_bound_check = true;
+#endif
 // input tensor
 constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
 in_n_c_hi_wi_global_desc,
-make_tuple(PassThrough<N>{},
-PassThrough<C>{},
-Pad<Sequence<Hi, Wi>, InputLeftPads, InputRightPads>{}),
+make_tuple(
+PassThrough<N>{},
+PassThrough<C>{},
+Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
+constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
+constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
 constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
 in_n_c_hip_wip_global_desc,
 make_tuple(PassThrough<N>{},
 PassThrough<C>{},
-Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
-Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+Embed<Hip,
+Sequence<Ytilda, Htilda>,
+Sequence<ConvDilationH, ConvStrideH, 0>,
+in_skip_all_out_of_bound_check>{},
+Embed<Wip,
+Sequence<Xtilda, Wtilda>,
+Sequence<ConvDilationW, ConvStrideW, 0>,
+in_skip_all_out_of_bound_check>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
-in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
-make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
-make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
-make_tuple(Sequence<0>{}, Sequence<1>{}));
+constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
+transform_tensor_descriptor(
+in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
+make_tuple(PassThrough<N>{},
+PassThrough<C>{},
+PassThrough<Ytilda>{},
+PassThrough<Xtilda>{},
+Slice<Sequence<Htilda, Wtilda>,
+Sequence<HtildaLeft, WtildaLeft>,
+Sequence<HtildaRight, WtildaRight>>{}),
+make_tuple(
+Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
+make_tuple(
+Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
+constexpr auto in_gemmm_gemmn_global_desc =
+transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
+make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{},
+Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
+make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
+make_tuple(Sequence<0>{}, Sequence<1>{}));
 // GEMM
 constexpr auto gridwise_gemm =
...
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// Ytilda*Xtilda number of GEMMs
// GemmM = C;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * YdotNonZero * XdotNonZero;
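// Illustrative note (editorial, not in the original source): with ConvStrides = {2, 2},
// ConvDilations = {1, 1} and a 3x3 filter, Ytilda = Xtilda = 2, so this kernel iterates
// over 4 GEMMs. The GEMM for (ytilda, xtilda) = (0, 0) sees YdotNonZero = XdotNonZero = 2
// (filter taps y = 0, 2), while the one for (1, 1) sees YdotNonZero = XdotNonZero = 1
// (only tap y = 1), which is what makes the GEMMs load-balanced rather than zero-padded.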
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
{
// this is a hack, should query this info from gridwise_gemm instead of duplicate its logic
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned(
Sequence<GemmKPerBlock, GemmMPerBlock>{}, Number<max_lds_align>{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned(
Sequence<GemmKPerBlock, GemmNPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
}
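// Illustrative note (editorial, not in the original source): assuming GemmKPerBlock = 8,
// GemmMPerBlock = GemmNPerBlock = 128, Float = float, and an LDS alignment that divides 128
// so that each aligned block descriptor occupies 8 * 128 = 1024 elements, the
// double-buffered total above is 2 * (1024 + 1024) * sizeof(float) = 16 KiB of LDS.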
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr bool wei_skip_all_out_of_bound_check = true;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#if 1 // debug
constexpr bool out_skip_all_out_of_bound_check = false;
#else
constexpr bool out_skip_all_out_of_bound_check = true;
#endif
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
// GEMMs
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
#if 1 // debug
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
#else
static_for<0, 1, 1>{}([&](auto ytilda_) {
static_for<0, 1, 1>{}([&](auto xtilda_) {
#endif
constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){};
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
// A matrix
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2, 4>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block);
// is synchronization necessary?
__syncthreads();
});
});
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = C
// GemmN = N * Htilda * Wtilda;
// GemmK = K * YdotNonZero * XdotNonZero
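// Editorial note, inferred from the parameters below: each instantiation of this kernel
// computes the single GEMM selected by (Iter_ytilda, Iter_xtilda); covering the whole
// input requires launching Ytilda * Xtilda such kernels, in contrast to v3r1, which
// loops over all of them inside one kernel.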
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t Iter_ytilda,
index_t Iter_xtilda,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr bool wei_skip_all_out_of_bound_check = true;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#if 1 // debug
constexpr bool out_skip_all_out_of_bound_check = false;
#else
constexpr bool out_skip_all_out_of_bound_check = true;
#endif
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
// GEMM
constexpr index_t ytilda = Iter_ytilda;
constexpr index_t xtilda = Iter_xtilda;
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
// A matrix
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{}, Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
};
} // namespace ck
#endif
@@ -181,12 +181,15 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
+constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
+constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
 constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
 in_n_c_hip_wip_global_desc,
 make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
 PassThrough<C>{},
-Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
-Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
+Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
...
@@ -97,12 +97,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
+constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
+constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
 constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
 in_n_c_hip_wip_global_desc,
 make_tuple(PassThrough<N>{},
 PassThrough<C>{},
-Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
-Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
+Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
...
@@ -41,15 +41,17 @@ struct PassThrough
 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-__host__ __device__ static constexpr bool
-IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
+__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
 {
 return true;
 }
 };
 // LowerLengths: Sequence<...>
-template <typename LowerLengths, typename LeftPads, typename RightPads>
+template <typename LowerLengths,
+typename LeftPads,
+typename RightPads,
+bool SkipIsValidCheck = false>
 struct Pad
 {
 static constexpr index_t nDim = LowerLengths::Size();
@@ -57,6 +59,13 @@ struct Pad
 using LowerIndex = MultiIndex<nDim>;
 using UpperIndex = MultiIndex<nDim>;
+__host__ __device__ explicit constexpr Pad()
+{
+static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim &&
+RightPads::GetSize() == nDim,
+"wrong! # of dimensions not consistent");
+}
 __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
 __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
@@ -81,17 +90,80 @@ struct Pad
 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-__host__ __device__ constexpr bool
-IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
-{
-bool flag = true;
-static_for<0, nDim, 1>{}([&](auto idim) {
-flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
-(idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
-});
-return flag;
+__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+{
+#if 1 // debug
+if(SkipIsValidCheck)
+{
+return true;
+}
+#endif
+bool flag = true;
+for(index_t i = 0; i < nDim; ++i)
+{
+flag = flag && LeftPads::At(i) == 0 && RightPads::At(i) == 0;
+}
+return flag;
+}
+};
+// LowerLengths: Sequence<...>
+// SliceBegins: Sequence<...>
+// SliceEnds: Sequence<...>
+template <typename LowerLengths, typename SliceBegins, typename SliceEnds>
+struct Slice
+{
+static constexpr index_t nDim = LowerLengths::Size();
+using LowerIndex = MultiIndex<nDim>;
+using UpperIndex = MultiIndex<nDim>;
+__host__ __device__ explicit constexpr Slice()
+{
+static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim &&
+SliceEnds::GetSize() == nDim,
+"wrong! # of dimensions not consistent");
+#if 0
+// TODO: would not compile, error on constexpr
+static_for<0, nDim, 1>{}([&](auto idim) {
+static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
+SliceBegins::At(idim) >= 0 &&
+SliceEnds::At(idim) <= LowerLengths::At(idim),
+"wrong! Slice config is wrong");
+});
+#endif
+}
+__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
+__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
+__host__ __device__ static constexpr auto GetUpperLengths()
+{
+return SliceEnds{} - SliceBegins{};
+}
+__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
+{
+return idx_up + SliceBegins{};
+}
+__host__ __device__ static constexpr auto
+CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
+const UpperIndex& /* idx_up_old */,
+const LowerIndex& /* idx_low_old */)
+{
+return idx_up_diff;
+}
+__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
+__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+{
+return true;
 }
 };
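Editorial note on the new Slice transform added above, with an illustrative example (not from the original source):

// Slice<Sequence<10, 10>, Sequence<2, 3>, Sequence<7, 9>> has upper lengths
// SliceEnds - SliceBegins = Sequence<5, 6>, and an upper index (i, j) maps to
// lower index (i + 2, j + 3); the index difference passes through unchanged.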
@@ -165,85 +237,101 @@ struct Merge
 const UpperIndex& /* idx_up_old */,
 const LowerIndex& idx_low_old)
 {
 // do nothing if idx_up_diff == 0
 if(idx_up_diff[0] == 0)
 {
 return make_zero_array<index_t, nDimLow>();
 }
-else
-{
-// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
-// If idx_up_diff is known at compile-time, the calculation can
-// be done at compile-time. However, if idx_up_diff is only known
-// at run-time, then the calculation will also be computed at
-// run-time, and can be very expensive.
-LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);
-if(idx_up_diff[0] > 0)
-{
-bool carry = false;
-// do carry check in reversed order, starting from lowest dimension
-// don't check the highest dimension
-static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
-constexpr index_t i = nDimLow - 1 - ireverse;
-if(carry)
-{
-++idx_low_new(i);
-}
-carry = false;
-if(idx_low_new[i] >= LowerLengths::At(i))
-{
-idx_low_new(i) -= LowerLengths::At(i);
-carry = true;
-}
-});
-// highest dimension, no out-of-bound check
-if(carry)
-{
-++idx_low_new(0);
-}
-}
-else if(idx_up_diff[0] < 0)
-{
-bool borrow = false;
-// do borrow check in reversed order, starting from lowest dimension
-// don't check the highest dimension
-static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
-constexpr index_t i = nDimLow - 1 - ireverse;
-if(borrow)
-{
---idx_low_new(i);
-}
-borrow = false;
-if(idx_low_new[i] < 0)
-{
-idx_low_new(i) += LowerLengths::At(i);
-borrow = true;
-}
-});
-// highest dimension, no out-of-bound check
-if(borrow)
-{
---idx_low_new(0);
-}
-}
-return idx_low_new - idx_low_old;
-}
+// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
+// If idx_up_diff is known at compile-time, the calculation can
+// be done at compile-time. However, if idx_up_diff is only known
+// at run-time, then the calculation will also be computed at
+// run-time, and can be very expensive.
+LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff);
+// find out the last low dimension that changed
+index_t last_changed_low_dim = 0;
+static_for<0, nDimLow, 1>{}([&](auto i) {
+if(idx_low_diff_tmp[i] != 0)
+{
+last_changed_low_dim = i;
+}
+});
+LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp;
+if(idx_up_diff[0] > 0)
+{
+// do carry check on each low dimension in reversed order
+// starting from the first digit that changed
+// don't check the highest dimension
+bool carry = false;
+static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
+if(i <= last_changed_low_dim)
+{
+if(carry)
+{
+++idx_low_new(i);
+}
+carry = false;
+if(idx_low_new[i] >= LowerLengths::At(i))
+{
+idx_low_new(i) -= LowerLengths::At(i);
+carry = true;
+}
+}
+});
+// highest dimension, no out-of-bound check
+if(carry)
+{
+++idx_low_new(0);
+}
+}
+else
+{
+// do borrow check on each low dimension in reversed order
+// starting from the first digit that changed
+// don't check the highest dimension
+bool borrow = false;
+static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
+if(i <= last_changed_low_dim)
+{
+if(borrow)
+{
+--idx_low_new(i);
+}
+borrow = false;
+if(idx_low_new[i] < 0)
+{
+idx_low_new(i) += LowerLengths::At(i);
+borrow = true;
+}
+}
+});
+// highest dimension, no out-of-bound check
+if(borrow)
+{
+--idx_low_new(0);
+}
+}
+return idx_low_new - idx_low_old;
 }
 __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
-__host__ __device__ static constexpr bool
-IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
+__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
 {
 return true;
 }
@@ -290,8 +378,7 @@ struct UnMerge
 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-__host__ __device__ static constexpr bool
-IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
+__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
 {
 return true;
 }
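Editorial note on the rewritten Merge::CalculateLowerIndexDiff above: the new version first records the last low dimension whose index actually changed, then runs the carry/borrow pass only from that digit upward. As an illustrative example (not from the original source), with LowerLengths = Sequence<4, 3, 5> and idx_low_diff_tmp = {0, 1, 0}, last_changed_low_dim is 1, so only dimensions 1 and 0 are touched and dimension 2 is skipped entirely.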
@@ -300,7 +387,10 @@ struct UnMerge
 // UpperLengths: Sequence<...>
 // Coefficients: Sequence<...>
 // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
-template <typename UpperLengths, typename Coefficients>
+template <index_t LowerLength,
+typename UpperLengths,
+typename Coefficients,
+bool SkipIsValidCheck = false>
 struct Embed
 {
 static constexpr index_t nDimLow = 1;
@@ -325,8 +415,10 @@ struct Embed
 {
 LowerIndex idx_low(Coefficients{}[nDimUp]);
-static_for<0, nDimUp, 1>{}(
-[&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
+for(index_t i = 0; i < nDimUp; ++i)
+{
+idx_low(0) += idx_up[i] * Coefficients{}[i];
+}
 return idx_low;
 }
@@ -338,18 +430,55 @@ struct Embed
 {
 LowerIndex idx_low_diff{0};
-static_for<0, nDimUp, 1>{}(
-[&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
+for(index_t i = 0; i < nDimUp; ++i)
+{
+idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i];
+}
 return idx_low_diff;
 }
 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-__host__ __device__ static constexpr bool
-IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
+__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
 {
-return true;
+#if 1 // debug
+if(SkipIsValidCheck)
+{
+return true;
+}
+#endif
+bool flag = true;
+index_t ncorner = 1;
+for(index_t idim = 0; idim < nDimUp; ++idim)
+{
+ncorner *= 2;
+}
+// loop over each corner of the upper tensor
+for(index_t icorner = 0; icorner < ncorner; ++icorner)
+{
+// generate upper index for each corner
+auto idx_up = make_zero_array<index_t, nDimUp>();
+index_t itmp = icorner;
+for(index_t idim = nDimUp - 1; idim >= 0; --idim)
+{
+idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1;
+itmp /= 2;
+}
+// calculate lower index
+auto idx_low = CalculateLowerIndex(idx_up);
+// judge if lower index is valid
+flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength;
+}
+return flag;
 }
 };
@@ -389,8 +518,7 @@ struct Vectorize
 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-__host__ __device__ static constexpr bool
-IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
+__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
 {
 return true;
 }
...
...@@ -53,6 +53,8 @@ struct NativeTensorCoordinate ...@@ -53,6 +53,8 @@ struct NativeTensorCoordinate
__host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; } __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
__host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; }
__host__ __device__ constexpr const Index& GetIndex() const { return mIndex; } __host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }
__host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; } __host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }
...@@ -98,7 +100,24 @@ struct NativeTensorCoordinate ...@@ -98,7 +100,24 @@ struct NativeTensorCoordinate
return tensor_desc_type::CalculateOffsetDiff(idx_diff); return tensor_desc_type::CalculateOffsetDiff(idx_diff);
} }
__host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; } // evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}
// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
// For native tensor, offset is valid if upper-index is valid
return IsUpperIndexValid();
}
// evaluated at compile-time
__host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid()
{
return true;
}
private: private:
// mIndex may be saved and updated, however, the value of some (or all) of its entries may // mIndex may be saved and updated, however, the value of some (or all) of its entries may
...@@ -206,10 +225,30 @@ struct TransformedTensorCoordinate ...@@ -206,10 +225,30 @@ struct TransformedTensorCoordinate
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff); return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
} }
__host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const // evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}
// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid();
}
// most of the evaluation is done at compile-time
__host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const
{
return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid(
GetLowerCoordinate().GetIndex());
}
// most of the evaluation is done at compile-time
__host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const
{ {
return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) && return IsLowerIndexValidAssumingUpperIndexIsValid() &&
mCoordLow.IsUpperIndexMappedToValidOffset(); GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid();
} }
private: private:
......
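The coordinate now exposes three separate queries: IsUpperIndexValid() bound-checks the upper index at run time, IsOffsetValid() walks the whole transform chain, and IsOffsetValidAssumingUpperIndexIsValid() is the cheap variant for the hot path, where the caller already iterates only over in-range upper indices. The following is a rough stand-in showing how the three layers relate; the Fake* types are illustrative only, not the library's coordinate classes.

#include <cstdio>

// lowest level: a native (strided) tensor; offset valid iff index in range
struct FakeNativeCoord
{
    int idx;
    int length;
    bool IsUpperIndexValid() const { return idx >= 0 && idx < length; }
    bool IsOffsetValid() const { return IsUpperIndexValid(); }
    bool IsOffsetValidAssumingUpperIndexIsValid() const { return true; }
};

// one transformed level on top of it
struct FakeTransformedCoord
{
    int idx_up;
    int up_length;
    bool transform_can_go_out_of_bounds; // e.g. a padded Embed
    FakeNativeCoord low;

    bool IsUpperIndexValid() const { return idx_up >= 0 && idx_up < up_length; }

    // full run-time check, walks the whole chain
    bool IsOffsetValid() const { return IsUpperIndexValid() && low.IsOffsetValid(); }

    // cheap check used by the copy kernels: skips the upper bound check and only
    // inspects levels whose transforms can actually produce an invalid lower index
    bool IsOffsetValidAssumingUpperIndexIsValid() const
    {
        const bool low_ok = !transform_can_go_out_of_bounds || low.IsUpperIndexValid();
        return low_ok && low.IsOffsetValidAssumingUpperIndexIsValid();
    }
};

int main()
{
    // an upper index that maps into the padding region of the lower tensor
    FakeTransformedCoord c{2, 4, true, FakeNativeCoord{-1, 8}};
    std::printf("%d %d\n", c.IsOffsetValid(), c.IsOffsetValidAssumingUpperIndexIsValid()); // 0 0
}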
...@@ -120,10 +120,17 @@ struct NativeTensorDescriptor ...@@ -120,10 +120,17 @@ struct NativeTensorDescriptor
return Tuple<>{}; return Tuple<>{};
} }
__host__ __device__ static constexpr bool // a multi-index is valid if there is a corresponding point for it in the tensor
IsUpperIndexMappedToValidOffset(const Index& /* idx */) __host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx)
{ {
return true; bool flag = true;
for(index_t i = 0; i < nDim; ++i)
{
flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i];
}
return flag;
} }
}; };
...@@ -467,33 +474,49 @@ struct TransformedTensorDescriptor ...@@ -467,33 +474,49 @@ struct TransformedTensorDescriptor
} }
#endif #endif
// a multi-index is valid if there is a corresponding point for it in the tensor
__host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const
{
bool flag = true;
for(index_t i = 0; i < nDimUp; ++i)
{
flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i];
}
return flag;
}
// this function exists for optimization purposes; it's called by the tensor coordinate
// it tells you whether a lower-index is valid, assuming the upper-index is valid
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low)
{ {
bool flag = true; bool flag = true;
static_for<0, nTransform, 1>{}([&](auto itran) { static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran); constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran)); // only check a transformation if it does not always have a valid mapping
constexpr bool is_valid_up_always_mapped_to_valid_low =
flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part)); decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();
if(!is_valid_up_always_mapped_to_valid_low)
{
constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
constexpr auto low_lengths_part =
GetLowerTensorDescriptor().GetLengths(low_dims_part);
const auto idx_low_part = to_array(pick_array_element(idx_low, low_dims_part));
for(index_t i = 0; i < low_dims_part.Size(); ++i)
{
flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
}
}
}); });
return flag; return flag;
} }
// Whenever this function is called, it will call CalculateLowerIndex() recursively.
// If you have created a tensor coordinate already, instead of calling this function,
// you should call TensorCoordinate::IsUpperIndexMappedToValidOffset() which would
// be less expensive.
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
{
return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
CalculateLowerIndex(idx_up));
}
}; };
} // namespace ck } // namespace ck
......
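In the composed descriptor, the new check walks the transforms and only bound-checks the lower sub-index of transforms that report they cannot guarantee validity (e.g. a padded Embed); PassThrough, Merge, and the like are skipped entirely. A host-side sketch of that dispatch follows, with hypothetical per-transform records standing in for the library's tuple machinery.

#include <cstdio>
#include <vector>

struct TransformCheck
{
    bool always_valid;        // IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    std::vector<int> low_idx; // lower sub-index written by this transform
    std::vector<int> low_len; // lengths of those lower dimensions
};

bool lower_index_valid_assuming_upper_is_valid(const std::vector<TransformCheck>& transforms)
{
    bool flag = true;
    for(const auto& t : transforms)
    {
        if(t.always_valid)
            continue; // e.g. PassThrough, Merge: nothing can go out of bounds
        for(std::size_t i = 0; i < t.low_idx.size(); ++i)
            flag = flag && t.low_idx[i] >= 0 && t.low_idx[i] < t.low_len[i];
    }
    return flag;
}

int main()
{
    // second transform behaves like a padded Embed that stepped into the left pad
    std::vector<TransformCheck> ts = {{true, {5}, {16}}, {false, {-1}, {32}}};
    std::printf("%d\n", lower_index_valid_assuming_upper_is_valid(ts)); // 0
}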
...@@ -50,9 +50,37 @@ template <index_t GridSize, ...@@ -50,9 +50,37 @@ template <index_t GridSize,
index_t CThreadCopyDstDataPerWrite> index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1 struct GridwiseGemmTransposedANormalBNormalC_v1
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global, __device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{ {
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
...@@ -64,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -64,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr auto M = a_k_m_global_desc.GetLengths()[1]; constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLengths()[1]; constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// don't do anything if K == 0
if(K == 0)
{
return;
}
// lds max alignment // lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N, BBlockCopyDstDataPerWrite_N,
...@@ -184,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -184,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr index_t b_block_space = constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
__shared__ Float p_a_block_double[2 * a_block_space]; Float* p_a_block_double = p_shared_block;
__shared__ Float p_b_block_double[2 * b_block_space]; Float* p_b_block_double = p_shared_block + 2 * a_block_space;
// register allocation for output // register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
...@@ -329,6 +363,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -329,6 +363,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
.Run(p_c_thread, p_c_global); .Run(p_c_thread, p_c_global);
} }
} }
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(p_a_global, p_b_global, p_c_global, p_shared_block);
}
}; };
} // namespace ck } // namespace ck
......
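Moving the LDS double buffers behind a pointer (with the original behaviour preserved by the three-argument Run overload) is what allows several GEMMs to share one kernel launch: an outer kernel can allocate a single scratch buffer sized by GetSharedMemoryNumberOfByte() and hand it to each GEMM in turn. A hypothetical outer kernel along those lines is sketched below; Gemm0/Gemm1 and the argument names are assumptions for illustration, not instantiations from this PR.

#include <hip/hip_runtime.h>

// hypothetical: fuse two gridwise GEMMs into one launch, reusing one LDS buffer
template <typename Gemm0, typename Gemm1>
__global__ void run_two_gemms(const float* a0, const float* b0, float* c0,
                              const float* a1, const float* b1, float* c1)
{
    extern __shared__ float p_shared[]; // dynamic LDS, sized on the host
    Gemm0{}.Run(a0, b0, c0, p_shared);
    __syncthreads(); // make sure GEMM 0 is done with the scratch buffer
    Gemm1{}.Run(a1, b1, c1, p_shared);
}

// host side (sketch):
//   std::size_t lds_byte = std::max(Gemm0::GetSharedMemoryNumberOfByte(),
//                                   Gemm1::GetSharedMemoryNumberOfByte());
//   hipLaunchKernelGGL((run_two_gemms<Gemm0, Gemm1>), grid, block, lds_byte, stream,
//                      a0, b0, c0, a1, b1, c1);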
...@@ -110,7 +110,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -110,7 +110,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Check src data's valid mapping situation, only check the first data in this src // Check src data's valid mapping situation, only check the first data in this src
// vector. It's the user's responsibility to make sure all data in the src vector // vector. It's the user's responsibility to make sure all data in the src vector
// has the valid/invalid mapping situation // has the valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{ {
move_data<SrcData, move_data<SrcData,
SrcDataPerRead, SrcDataPerRead,
...@@ -142,7 +142,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -142,7 +142,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Check dst data's valid mapping situation, only check the first data in this dst // Check dst data's valid mapping situation, only check the first data in this dst
// vector. It's the user's responsibility to make sure all data in the dst vector // vector. It's the user's responsibility to make sure all data in the dst vector
// has the valid/invalid mapping situation // has the valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{ {
move_data<DstData, move_data<DstData,
DstDataPerWrite, DstDataPerWrite,
...@@ -260,7 +260,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -260,7 +260,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// src // src
// vector. It's the user's responsibility to make sure all data in the src vector // vector. It's the user's responsibility to make sure all data in the src vector
// has the valid/invalid mapping situation // has the valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{ {
move_data<SrcData, move_data<SrcData,
SrcDataPerRead, SrcDataPerRead,
...@@ -299,7 +299,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -299,7 +299,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// dst // dst
// vector. It's the user's responsibility to make sure all data in the dst vector // vector. It's the user's responsibility to make sure all data in the dst vector
// has the valid/invalid mapping situation // has the valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{ {
move_data<DstData, move_data<DstData,
DstDataPerWrite, DstDataPerWrite,
...@@ -399,7 +399,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -399,7 +399,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// src // src
// vector. It's the user's responsibility to make sure all data in the src vector // vector. It's the user's responsibility to make sure all data in the src vector
// has the valid/invalid mapping situation // has the valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{ {
move_data<SrcData, move_data<SrcData,
SrcDataPerRead, SrcDataPerRead,
...@@ -444,7 +444,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -444,7 +444,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// dst // dst
// vector. It's the user's responsibility to make sure all data in the dst vector // vector. It's the user's responsibility to make sure all data in the dst vector
// has the valid/invalid mapping situation // has the valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{ {
move_data<DstData, move_data<DstData,
DstDataPerWrite, DstDataPerWrite,
......
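The rename makes the contract explicit: on the hot copy path only the cheaper check runs, and only the first element of each read/write vector is inspected; elements whose offset is invalid (typically writes that would land in the padding region of the input gradient) are skipped rather than causing an out-of-bounds access. The same guarded pattern, reduced to a scalar host-side sketch with made-up buffers:

#include <cstdio>
#include <vector>

// dst_offset[i] < 0 stands for "this element maps into padding"
void guarded_scatter_add(const std::vector<float>& src,
                         const std::vector<int>& dst_offset,
                         std::vector<float>& dst)
{
    for(std::size_t i = 0; i < src.size(); ++i)
    {
        const bool offset_valid =
            dst_offset[i] >= 0 && dst_offset[i] < static_cast<int>(dst.size());
        if(offset_valid) // counterpart of IsOffsetValidAssumingUpperIndexIsValid()
            dst[dst_offset[i]] += src[i];
    }
}

int main()
{
    std::vector<float> src{1.f, 2.f, 3.f};
    std::vector<int> off{0, -1, 2}; // middle element falls into padding, dropped
    std::vector<float> dst(4, 0.f);
    guarded_scatter_add(src, off, dst);
    std::printf("%.0f %.0f %.0f %.0f\n", dst[0], dst[1], dst[2], dst[3]); // 1 0 3 0
}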
...@@ -54,6 +54,13 @@ __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata, ...@@ -54,6 +54,13 @@ __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
bool glc, bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32"); bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
__device__ void
__llvm_amdgcn_buffer_atomic_add(float vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
// buffer_load requires: // buffer_load requires:
// 1) p_src must be in global memory space, d_dst must be vgpr // 1) p_src must be in global memory space, d_dst must be vgpr
// 2) p_src to be a block-invariant pointer. // 2) p_src to be a block-invariant pointer.
...@@ -73,6 +80,13 @@ amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType ...@@ -73,6 +80,13 @@ amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType
index_t dst_thread_data_offset, index_t dst_thread_data_offset,
index_t dst_const_data_offset); index_t dst_const_data_offset);
template <typename T, index_t VectorSize>
__device__ void
amd_intrinsic_buffer_atomic_add(const typename vector_type<T, VectorSize>::MemoryType& src,
T* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset);
template <> template <>
__device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block, __device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block,
index_t src_thread_data_offset, index_t src_thread_data_offset,
...@@ -289,5 +303,31 @@ __device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src, ...@@ -289,5 +303,31 @@ __device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src,
#endif #endif
} }
template <>
__device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src,
float* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
BufferLoadStoreDwordConfig<float> dst_block_config;
// fill in dword 0 - 1: base address of the destination buffer
dst_block_config.address[0] = p_dst_block;
// fill in dword 2: number of records (-1 = no limit)
dst_block_config.range[2] = -1;
// fill in dword 3: buffer resource descriptor configuration
dst_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
__llvm_amdgcn_buffer_atomic_add(
src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
#else
static_assert(false, "Wrong! not implemented");
#endif
}
} // namespace ck } // namespace ck
#endif #endif
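The new specialization builds the usual 128-bit buffer resource (base address in dwords 0-1, the -1 record count in dword 2, the 0x00027000 configuration in dword 3) and issues llvm.amdgcn.buffer.atomic.fadd.f32, i.e. a single hardware float atomic add per element. Functionally it behaves like the CAS loop below, which is only a host-side illustration of the semantics, not how the GPU executes it.

#include <atomic>
#include <cstdio>

// emulate a float atomic add with a compare-and-swap loop
float atomic_add_float(std::atomic<float>& dst, float v)
{
    float old = dst.load();
    while(!dst.compare_exchange_weak(old, old + v))
    {
        // 'old' is refreshed by compare_exchange_weak on failure; retry with it
    }
    return old;
}

int main()
{
    std::atomic<float> acc{1.0f};
    atomic_add_float(acc, 2.5f);
    std::printf("%.1f\n", acc.load()); // 3.5
}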
...@@ -29,6 +29,11 @@ ...@@ -29,6 +29,11 @@
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1 #define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1
#endif #endif
// buffer atomic add is only supported on gfx908
#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD
#define CK_USE_AMD_BUFFER_ATOMIC_ADD 0
#endif
// AMD XDLOPS // AMD XDLOPS
#ifndef CK_USE_AMD_XDLOPS #ifndef CK_USE_AMD_XDLOPS
#define CK_USE_AMD_XDLOPS 0 #define CK_USE_AMD_XDLOPS 0
......
...@@ -29,9 +29,10 @@ struct static_for ...@@ -29,9 +29,10 @@ struct static_for
{ {
__host__ __device__ constexpr static_for() __host__ __device__ constexpr static_for()
{ {
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd"); static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
static_assert((NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
"wrongs! should have NBegin <= NEnd");
} }
template <class F> template <class F>
......
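The relaxed asserts allow static_for to count downwards (negative Increment), which the new backward-data kernels rely on. Below is a small stand-in, not the library's implementation, showing both loop directions under the same two asserts.

#include <cstdio>

template <int NBegin, int NEnd, int Increment>
struct static_for_sketch
{
    static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
                  "(NEnd - NBegin) must be a multiple of Increment");
    static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
                  "the loop must move from NBegin towards NEnd");

    template <class F>
    void operator()(F f) const
    {
        if constexpr(NBegin != NEnd)
        {
            f(NBegin);
            static_for_sketch<NBegin + Increment, NEnd, Increment>{}(f);
        }
    }
};

int main()
{
    static_for_sketch<0, 4, 1>{}([](int i) { std::printf("%d ", i); });  // 0 1 2 3
    static_for_sketch<4, 0, -1>{}([](int i) { std::printf("%d ", i); }); // 4 3 2 1
    std::printf("\n");
}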
...@@ -52,8 +52,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in ...@@ -52,8 +52,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}( static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
[&](auto) { [&](auto) {
#if CK_USE_AMD_BUFFER_ATOMIC_ADD
amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
*reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
#else
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]), atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset])); *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
#endif
}) })
.Else([&](auto fwd) { .Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space"); static_assert(fwd(false), "atomic_add doesn't support this memory space");
......
...@@ -49,6 +49,12 @@ struct integer_divide_ceiler ...@@ -49,6 +49,12 @@ struct integer_divide_ceiler
} }
}; };
template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
return x / y;
}
template <class X, class Y> template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{ {
......
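Worth noting: like the helper above, plain x / y only behaves as a floor for non-negative operands, since C++ integer division truncates toward zero; that is the regime these kernels operate in. A two-line illustration with made-up helper names:

#include <cstdio>

constexpr int divide_floor(int x, int y) { return x / y; }          // 7 / 3 -> 2
constexpr int divide_ceil(int x, int y) { return (x + y - 1) / y; } // 7 / 3 -> 3

int main()
{
    std::printf("%d %d\n", divide_floor(7, 3), divide_ceil(7, 3)); // 2 3
}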
...@@ -24,9 +24,9 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA") ...@@ -24,9 +24,9 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu) set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu)
endif() endif()
add_executable(conv ${CONV_SOURCE}) add_executable(conv_driver ${CONV_SOURCE})
add_executable(col2im ${COL2IM_SOURCE}) add_executable(col2im_driver ${COL2IM_SOURCE})
add_executable(conv_bwd_data ${CONV_BWD_DATA_SOURCE}) add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
target_link_libraries(conv PRIVATE host) target_link_libraries(conv_driver PRIVATE host)
target_link_libraries(col2im PRIVATE host) target_link_libraries(col2im_driver PRIVATE host)
target_link_libraries(conv_bwd_data PRIVATE host) target_link_libraries(conv_bwd_data_driver PRIVATE host)
...@@ -30,33 +30,81 @@ struct KernelTimer ...@@ -30,33 +30,81 @@ struct KernelTimer
std::unique_ptr<KernelTimerImpl> impl; std::unique_ptr<KernelTimerImpl> impl;
}; };
#if CK_DEVICE_BACKEND_AMD
using device_stream_t = hipStream_t;
template <typename... Args, typename F> template <typename... Args, typename F>
float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) void launch_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
{
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}
template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
{ {
KernelTimer timer; KernelTimer timer;
#if CK_DEVICE_BACKEND_AMD
timer.Start(); timer.Start();
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, 0, args...); hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
timer.End(); timer.End();
hipGetErrorString(hipGetLastError()); hipGetErrorString(hipGetLastError());
return timer.GetElapsedTime();
}
#elif CK_DEVICE_BACKEND_NVIDIA #elif CK_DEVICE_BACKEND_NVIDIA
using device_stream_t = cudaStream_t;
template <typename... Args, typename F>
void launch_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
cudaStream_t stream_id,
Args... args)
{
const void* f = reinterpret_cast<const void*>(kernel);
void* p_args[] = {&args...};
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
}
template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
cudaStream_t stream_id,
Args... args)
{
KernelTimer timer;
const void* f = reinterpret_cast<const void*>(kernel); const void* f = reinterpret_cast<const void*>(kernel);
void* p_args[] = {&args...}; void* p_args[] = {&args...};
timer.Start(); timer.Start();
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, 0); cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
timer.End(); timer.End();
checkCudaErrors(error); checkCudaErrors(error);
#endif
return timer.GetElapsedTime(); return timer.GetElapsedTime();
} }
#endif
#endif #endif
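With the stream parameter and the split into launch_kernel (enqueue only) and launch_and_time_kernel (wrapped in KernelTimer), a caller on the AMD backend could look like the sketch below; my_scale_kernel, its arguments, and the buffer names are made up for the example, only the two launch helpers come from this change.

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void my_scale_kernel(float* p, float s, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] *= s;
}

void example(float* p_dev, int n)
{
    hipStream_t stream = nullptr; // default stream

    // fire-and-forget launch
    launch_kernel(my_scale_kernel, dim3((n + 255) / 256), dim3(256), 0, stream,
                  p_dev, 2.0f, n);

    // timed launch, returns elapsed milliseconds
    float ms = launch_and_time_kernel(my_scale_kernel, dim3((n + 255) / 256), dim3(256),
                                      0, stream, p_dev, 2.0f, n);
    std::printf("%f ms\n", ms);
}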
...@@ -88,7 +88,8 @@ void device_col2im_eb_nchw(ColDesc, ...@@ -88,7 +88,8 @@ void device_col2im_eb_nchw(ColDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_col2im), float time =
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_col2im),
const T* const __restrict__, const T* const __restrict__,
T* const __restrict__>, T* const __restrict__>,
dim3(GridSize), dim3(GridSize),
......