Unverified Commit e2b4c5b4 authored by Chao Liu, committed by GitHub

update implicit GEMM forward v4r4 to use gridwise gemm (#9)

* updated fwd v4r4 to use gridwise gemm
* updated gridwise gemm api calls in bwd-data v1r1 and v2r1
parent 19a93dac
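
The forward v4r4 kernel added in this commit (and the other kernels below) treats an NCHW/KCYX convolution as one large GEMM; for v4r4 forward the mapping is GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X. As a reference for that mapping, here is a minimal host-side sketch; it is illustrative only (not part of this diff), and all names and the zero-padding convention are assumptions:

// Illustrative host-side reference of the implicit-GEMM view (not part of this commit).
#include <vector>

// GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X
void conv_fwd_as_gemm(const std::vector<float>& wei,  // [K][C][Y][X]
                      const std::vector<float>& in,   // [N][C][Hi][Wi]
                      std::vector<float>& out,        // [N][K][Ho][Wo]
                      int N, int C, int Hi, int Wi, int K, int Y, int X,
                      int Ho, int Wo, int sH, int sW, int dH, int dW, int pH, int pW)
{
    for(int k = 0; k < K; ++k)               // GemmM
        for(int b = 0; b < N * Ho * Wo; ++b) // GemmN
        {
            int n = b / (Ho * Wo), ho = (b / Wo) % Ho, wo = b % Wo;
            float acc = 0;
            for(int e = 0; e < C * Y * X; ++e) // GemmK
            {
                int c  = e / (Y * X), y = (e / X) % Y, x = e % X;
                int hi = ho * sH + y * dH - pH;
                int wi = wo * sW + x * dW - pW;
                float a = wei[((k * C + c) * Y + y) * X + x];
                // out-of-bounds input positions read as zero (padding)
                float bval = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                                 ? in[((n * C + c) * Hi + hi) * Wi + wi]
                                 : 0.f;
                acc += a * bval;
            }
            out[((n * K + k) * Ho + ho) * Wo + wo] = acc;
        }
}
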
@@ -8,6 +8,9 @@
namespace ck {

+// GemmM = C * Y * X
+// GemmN = N * Ho * Wo
+// GemmK = K
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
@@ -17,11 +20,11 @@ template <index_t GridSize,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
-          typename LeftPads,
-          typename RightPads,
-          index_t EPerBlock,
-          index_t BPerBlock,
-          index_t KPerBlock,
+          typename InLeftPads,
+          typename InRightPads,
+          index_t GemmMPerBlock,
+          index_t GemmNPerBlock,
+          index_t GemmKPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
@@ -31,13 +34,15 @@ template <index_t GridSize,
          index_t GemmKPerThreadLoop,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
-          typename WeiBlockCopySubLengths_K_E,
-          typename WeiBlockCopyClusterLengths_K_E,
-          index_t WeiBlockCopyDataPerAccess_E,
-          typename OutBlockCopySubLengths_K_B,
-          typename OutBlockCopyClusterLengths_K_B,
-          index_t OutBlockCopyDataPerAccess_B,
-          index_t InThreadCopyDataPerAccess_B>
+          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+          index_t GemmABlockCopySrcDataPerRead_GemmN,
+          index_t GemmABlockCopyDstDataPerWrite_GemmN,
+          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+          index_t GemmBBlockCopySrcDataPerRead_GemmN,
+          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
+          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
    __device__ void Run(Float* __restrict__ p_in_global,
@@ -49,8 +54,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

-        constexpr auto True = integral_constant<bool, true>{};
-
        constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
@@ -73,12 +76,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

-        constexpr index_t E = C * Y * X;
-        constexpr index_t B = N * Ho * Wo;
-
        // sanity-check for vectorized memory load
-        static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
-                          (X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
+        // TODO: this logic may not be correct for bwd-data
+        static_assert(
+            (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
+                (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
            "wrong! aligment requirement for vectorized global load of input tensor will "
            "be violated");
@@ -99,8 +101,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
-            make_tuple(
-                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
+            make_tuple(PassThrough<N>{},
+                       PassThrough<C>{},
+                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
@@ -121,7 +124,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
        // GEMM: atomic add
        constexpr auto gridwise_gemm =
-            GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
+            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
                BlockSize,
                Float,
                AccFloat,
@@ -129,9 +132,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
                decltype(out_k_b_global_desc),
                decltype(in_e_b_global_desc),
                InMemoryDataOperation::atomic_add,
-                EPerBlock,
-                BPerBlock,
-                KPerBlock,
+                GemmMPerBlock,
+                GemmNPerBlock,
+                GemmKPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
@@ -141,13 +144,23 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
                GemmKPerThreadLoop,
                GemmThreadGemmDataPerReadM,
                GemmThreadGemmDataPerReadN,
-                WeiBlockCopySubLengths_K_E,
-                WeiBlockCopyClusterLengths_K_E,
-                WeiBlockCopyDataPerAccess_E,
-                OutBlockCopySubLengths_K_B,
-                OutBlockCopyClusterLengths_K_B,
-                OutBlockCopyDataPerAccess_B,
-                InThreadCopyDataPerAccess_B>{};
+                GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+                GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+                Sequence<0, 1>,
+                Sequence<0, 1>,
+                1,
+                GemmABlockCopySrcDataPerRead_GemmN,
+                GemmABlockCopyDstDataPerWrite_GemmN,
+                GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+                GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+                Sequence<0, 1>,
+                Sequence<0, 1>,
+                1,
+                GemmBBlockCopySrcDataPerRead_GemmN,
+                GemmBBlockCopyDstDataPerWrite_GemmN,
+                Sequence<0, 1, 2, 3>,
+                3,
+                GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }
...
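
In the backward-data v1r1 kernel above, the GEMM's C matrix is the input gradient addressed through the same im2col-style map as the forward pass, so several (y, ho) and (x, wo) combinations can land on the same input pixel; that is why the C write uses InMemoryDataOperation::atomic_add. A host-side reference of the accumulation, illustrative only (not part of this diff), with assumed names and layouts:

// Illustrative reference for bwd-data accumulation (not part of this commit).
#include <vector>

// GemmM = C * Y * X, GemmN = N * Ho * Wo, GemmK = K
void conv_bwd_data_reference(const std::vector<float>& wei,  // [K][C][Y][X]
                             const std::vector<float>& dout, // [N][K][Ho][Wo]
                             std::vector<float>& din,        // [N][C][Hi][Wi], pre-zeroed
                             int N, int C, int Hi, int Wi, int K, int Y, int X,
                             int Ho, int Wo, int sH, int sW, int dH, int dW, int pH, int pW)
{
    for(int e = 0; e < C * Y * X; ++e)       // GemmM
        for(int b = 0; b < N * Ho * Wo; ++b) // GemmN
        {
            int c = e / (Y * X), y = (e / X) % Y, x = e % X;
            int n = b / (Ho * Wo), ho = (b / Wo) % Ho, wo = b % Wo;
            int hi = ho * sH + y * dH - pH;
            int wi = wo * sW + x * dW - pW;
            if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
                continue; // this C element falls into padding; nothing to accumulate

            float acc = 0;
            for(int k = 0; k < K; ++k) // GemmK
                acc += wei[((k * C + c) * Y + y) * X + x] *
                       dout[((n * K + k) * Ho + ho) * Wo + wo];

            // the "+=" is what atomic_add provides when GEMM tiles run in parallel
            din[((n * C + c) * Hi + hi) * Wi + wi] += acc;
        }
}
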
@@ -8,9 +8,9 @@
namespace ck {

-// GemmK = K * Ydot * Xdot;
// GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda;
+// GemmK = K * Ydot * Xdot;
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
@@ -34,14 +34,15 @@ template <index_t GridSize,
          index_t GemmKPerThreadLoop,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
-          typename GemmABlockCopySubLengths,     // Gemm-K, Gemm-M
-          typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M
-          index_t GemmABlockCopyDataPerAccess,   // Gemm-M
-          typename GemmBBlockCopySubLengths,     // Gemm-K, Gemm-N
-          typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N
-          index_t GemmBBlockCopyDataPerAccess,   // Gemm-N
-          index_t GemmCThreadCopyDataPerAccess   // Gemm-N
-          >
+          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+          index_t GemmABlockCopySrcDataPerRead_GemmM,
+          index_t GemmABlockCopyDstDataPerWrite_GemmM,
+          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+          index_t GemmBBlockCopySrcDataPerRead_GemmN,
+          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
+          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
    __device__ void Run(Float* __restrict__ p_in_global,
@@ -71,8 +72,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        // sanity-check for vectorized memory load
-        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) &&
-                          (X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0),
+        // TODO: this logic may not be correct for bwd-data
+        static_assert(
+            (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
+                (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
            "wrong! aligment requirement for vectorized global load of input tensor will "
            "be violated");
@@ -172,7 +175,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
        // GEMM
        constexpr auto gridwise_gemm =
-            GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
+            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
                BlockSize,
                Float,
                AccFloat,
@@ -192,13 +195,23 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                GemmKPerThreadLoop,
                GemmThreadGemmDataPerReadM,
                GemmThreadGemmDataPerReadN,
-                GemmABlockCopySubLengths,
-                GemmABlockCopyClusterLengths,
-                GemmABlockCopyDataPerAccess,
-                GemmBBlockCopySubLengths,
-                GemmBBlockCopyClusterLengths,
-                GemmBBlockCopyDataPerAccess,
-                GemmCThreadCopyDataPerAccess>{};
+                GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+                GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+                Sequence<0, 1>,
+                Sequence<0, 1>,
+                1,
+                GemmABlockCopySrcDataPerRead_GemmM,
+                GemmABlockCopyDstDataPerWrite_GemmM,
+                GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+                GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+                Sequence<0, 1>,
+                Sequence<0, 1>,
+                1,
+                GemmBBlockCopySrcDataPerRead_GemmN,
+                GemmBBlockCopyDstDataPerWrite_GemmN,
+                Sequence<0, 1, 2, 3>,
+                3,
+                GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }
...
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 &&
InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// weight tensor
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_k_b_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_e_k_global_desc),
decltype(in_e_b_global_desc),
decltype(out_k_b_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
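
The static_assert guarding GemmBBlockCopySrcDataPerRead_GemmN in the v4r4 kernel above encodes when a vectorized read along GemmN is legal: consecutive wo positions are ConvStrideW apart in the padded input row, and the read start is shifted by x * ConvDilationW and by the W padding. A compile-time restatement of that rule, illustrative only (not part of this diff), with assumed parameter names:

// Illustrative restatement of the input vector-read alignment rule (not part of this commit).
constexpr bool input_vector_read_ok(int Wo, int X, int ConvStrideW, int ConvDilationW,
                                    int LeftPadW, int RightPadW, int VecSize)
{
    return (Wo == 1 || ConvStrideW == 1 || VecSize == 1) && // consecutive wo must be contiguous
           (X == 1 || ConvDilationW % VecSize == 0) &&      // shifting by x keeps alignment
           LeftPadW % VecSize == 0 && RightPadW % VecSize == 0; // padding keeps alignment

// Example configuration (values are made up): 3x3 filter, stride 1, dilation 1, pad 1, scalar reads.
static_assert(input_vector_read_ok(28, 3, 1, 1, 1, 1, 1), "example configuration should pass");
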
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_B,
typename InBlockCopyClusterLengths_E_B,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc =
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
constexpr auto out_n_k_ho_wo_global_desc =
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// LDS mem
// be careful of LDS alignment
constexpr auto in_e_b_block_desc =
make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
// input blockwise copy
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// global mem
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// LDS
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// weight blockwise copy
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, BPerBlock] is in LDS
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check
static_assert(
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
constexpr index_t GemmNRepeat =
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO:: more elegant way of defining c_thread_mtx
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(c_k0k1_b0b1_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store last data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
}
}
// copy output: register to global memory
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
// src descriptor
constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
// dst descriptor
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t K0 = K / K1;
constexpr index_t B0 = B / B1;
constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
out_k_b_global_desc,
make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// output threadwise copy
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_b0_b1_thread_desc),
decltype(out_k0_k1_b0_b1_global_desc),
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 4, 1>::type,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B,
AddressSpace::vgpr,
AddressSpace::global>({0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1})
.Run(p_out_thread, p_out_global);
}
}
};
} // namespace ck
#endif
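
The main loop of the LDS double-buffer kernel above alternates between two LDS buffers: while the blockwise GEMM consumes the tile that is already resident, the blockwise copies stage the next E-tile into the other buffer, followed by a one- or two-iteration tail. The schedule, reduced to a host-side analogue, illustrative only (not part of this diff):

// Illustrative host-side analogue of the double-buffer schedule (not part of this commit).
#include <functional>

void double_buffered_loop(int num_k_tiles,
                          const std::function<void(int /*tile*/, int /*buf*/)>& load,
                          const std::function<void(int /*buf*/)>& gemm)
{
    load(0, 0); // preload the first tile into buffer 0

    for(int tile = 0; tile + 1 < num_k_tiles; ++tile)
    {
        int cur = tile % 2, next = 1 - cur;
        load(tile + 1, next); // stage the next tile into the other buffer
        gemm(cur);            // compute on the tile that is already resident
    }

    gemm((num_k_tiles - 1) % 2); // tail: the last tile has no successor to stage
}
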
@@ -9,6 +9,12 @@
namespace ck {

+// This threadwise copy allow vector access of src and dst.
+// It allows the vector size to be different on src and dst.
+// The dimension of vector access can be different for src and dst.
+// The dimension access order can be different for src and dst.
+// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
+// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
template <index_t BlockSize,
          typename BlockSrcDesc,
          typename BlockDstDesc,
@@ -18,10 +24,10 @@ template <index_t BlockSize,
          typename ThreadClusterArrangeOrder,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
-          index_t SrcVectorAccessDim,
-          index_t DstVectorAccessDim,
-          index_t SrcDataPerAccess,
-          index_t DstDataPerAccess,
+          index_t SrcVectoReadDim,
+          index_t DstVectorWriteDim,
+          index_t SrcDataPerRead,
+          index_t DstDataPerWrite,
          AddressSpace SrcAddressSpace = AddressSpace::generic,
          AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
          AddressSpace DstAddressSpace = AddressSpace::generic,
@@ -146,8 +152,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
            ThreadBufferDesc,
            ThreadSliceLengths,
            SrcDimAccessOrder,
-            SrcVectorAccessDim,
-            SrcDataPerAccess,
+            SrcVectoReadDim,
+            SrcDataPerRead,
            1,
            SrcAddressSpace,
            ThreadBufferAddressSpace,
@@ -157,9 +163,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
            BlockDstDesc,
            ThreadSliceLengths,
            DstDimAccessOrder,
-            DstVectorAccessDim,
+            DstVectorWriteDim,
            1,
-            DstDataPerAccess,
+            DstDataPerWrite,
            ThreadBufferAddressSpace,
            DstAddressSpace,
            DstInMemOp>;
...
@@ -31,14 +31,24 @@ template <index_t GridSize,
          index_t KPerThreadLoop,
          index_t ThreadGemmDataPerReadM,
          index_t ThreadGemmDataPerReadN,
-          typename ABlockCopySubLengths_K_M,
-          typename ABlockCopyClusterLengths_K_M,
-          index_t ABlockCopyDataPerAccess_M,
-          typename BBlockCopySubLengths_K_N,
-          typename BBlockCopyClusterLengths_K_N,
-          index_t BBlockCopyDataPerAccess_N,
-          index_t CThreadCopyDataPerAccess_N>
-struct GridwiseGemmTransposedANormalBNormalC_v1r1
+          typename ABlockCopyThreadSliceLengths_K_M,
+          typename ABlockCopyThreadClusterLengths_K_M,
+          typename ABlockCopyThreadClusterArrangeOrder,
+          typename ABlockCopySrcAccessOrder,
+          index_t ABlockCopySrcVectorReadDim,
+          index_t ABlockCopySrcDataPerRead,
+          index_t ABlockCopyDstDataPerWrite_M,
+          typename BBlockCopyThreadSliceLengths_K_N,
+          typename BBlockCopyThreadClusterLengths_K_N,
+          typename BBlockCopyThreadClusterArrangeOrder,
+          typename BBlockCopySrcAccessOrder,
+          index_t BBlockCopySrcVectorReadDim,
+          index_t BBlockCopySrcDataPerRead,
+          index_t BBlockCopyDstDataPerWrite_N,
+          typename CThreadCopySrcDstAccessOrder,
+          index_t CThreadCopySrcDstVectorReadWriteDim,
+          index_t CThreadCopyDstDataPerWrite>
+struct GridwiseGemmTransposedANormalBNormalC_v1
{
    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
@@ -55,8 +65,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
        constexpr auto N = b_k_n_global_desc.GetLengths()[1];

        // lds max alignment
-        constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
-                                                    BBlockCopyDataPerAccess_N,
+        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
+                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmDataPerReadM,
                                                    ThreadGemmDataPerReadN);
@@ -86,15 +96,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
            decltype(a_k_m_global_desc),
            decltype(a_k_m_block_desc),
            decltype(a_k_m_block_desc.GetLengths()),
-            ABlockCopySubLengths_K_M,
-            ABlockCopyClusterLengths_K_M,
-            Sequence<0, 1>,
-            Sequence<0, 1>,
-            Sequence<0, 1>,
-            1,
-            1,
-            ABlockCopyDataPerAccess_M,
-            ABlockCopyDataPerAccess_M,
+            ABlockCopyThreadSliceLengths_K_M,
+            ABlockCopyThreadClusterLengths_K_M,
+            ABlockCopyThreadClusterArrangeOrder,
+            ABlockCopySrcAccessOrder,
+            Sequence<0, 1>,
+            ABlockCopySrcVectorReadDim,
+            1,
+            ABlockCopySrcDataPerRead,
+            ABlockCopyDstDataPerWrite_M,
            AddressSpace::global,
            AddressSpace::vgpr,
            AddressSpace::lds,
@@ -112,15 +122,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
            decltype(b_k_n_global_desc),
            decltype(b_k_n_block_desc),
            decltype(b_k_n_block_desc.GetLengths()),
-            BBlockCopySubLengths_K_N,
-            BBlockCopyClusterLengths_K_N,
-            Sequence<0, 1>,
-            Sequence<0, 1>,
-            Sequence<0, 1>,
-            1,
-            1,
-            BBlockCopyDataPerAccess_N,
-            BBlockCopyDataPerAccess_N,
+            BBlockCopyThreadSliceLengths_K_N,
+            BBlockCopyThreadClusterLengths_K_N,
+            BBlockCopyThreadClusterArrangeOrder,
+            BBlockCopySrcAccessOrder,
+            Sequence<0, 1>,
+            BBlockCopySrcVectorReadDim,
+            1,
+            BBlockCopySrcDataPerRead,
+            BBlockCopyDstDataPerWrite_N,
            AddressSpace::global,
            AddressSpace::vgpr,
            AddressSpace::lds,
@@ -304,10 +314,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
        ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
                decltype(c_m0_m1_n0_n1_global_desc),
                decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
-                Sequence<0, 1, 2, 3>,
-                3,
-                CThreadCopyDataPerAccess_N,
-                CThreadCopyDataPerAccess_N,
+                CThreadCopySrcDstAccessOrder,
+                CThreadCopySrcDstVectorReadWriteDim,
+                1,
+                CThreadCopyDstDataPerWrite,
                AddressSpace::vgpr,
                AddressSpace::global,
                CGlobalMemoryDataOperation>(
...
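
For reference, the renamed GridwiseGemmTransposedANormalBNormalC_v1 computes C = transpose(A) * B with A stored K-major as [K, M], B as [K, N] and C as [M, N]; the kernel tiles M and N across workgroups and steps over K in KPerBlock chunks. A plain host-side statement of that contract, illustrative only (not part of this diff):

// Illustrative reference for the "transposed A, normal B, normal C" GEMM (not part of this commit).
#include <vector>

void gemm_tA_nB_nC(const std::vector<float>& a, // [K][M]
                   const std::vector<float>& b, // [K][N]
                   std::vector<float>& c,       // [M][N]
                   int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0;
            for(int k = 0; k < K; ++k)
                acc += a[k * M + m] * b[k * N + n]; // C[m][n] += A[k][m] * B[k][n]
            c[m * N + n] = acc;
        }
}
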
@@ -8,20 +8,19 @@
namespace ck {

-// This version use multi-index transformation
// This threadwise copy allow vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimensions of vector access should be the same on src and dst.
// The dimension access order should be the same on src and dst.
-// It is designed for cases, where one of src and dst is register, and
-// the other is device memory or LDS
+// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
+// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
template <typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
-          typename DimAccessOrder,
-          index_t VectorAccessDim,
-          index_t SrcDataPerAccess,
-          index_t DstDataPerAccess,
+          typename SrcDstDimAccessOrder,
+          index_t SrcDstVectorReadWriteDim,
+          index_t SrcDataPerRead,
+          index_t DstDataPerWrite,
          AddressSpace SrcAddressSpace = AddressSpace::generic,
          AddressSpace DstAddressSpace = AddressSpace::generic,
          InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
@@ -39,16 +38,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
    {
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
-                          nDim == DimAccessOrder::Size(),
+                          nDim == SrcDstDimAccessOrder::Size(),
                      "wrong! # of dimensions not the same");

-        static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid");
+        static_assert(is_valid_sequence_map<SrcDstDimAccessOrder>{}, "wrong! map is not valid");

-        static_assert(
-            SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0,
+        static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] %
+                              math::lcm(SrcDataPerRead, DstDataPerWrite) ==
+                          0,
                      "wrong! cannot evenly divide");

-        // TODO:: sanity-check if vectorized memory access is allowed on src and dst
+        // TODO:: sanity-check if vectorized memory read/write is allowed on src and dst
    }

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
@@ -67,22 +67,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
        mDstSliceOrigin = dst_slice_origin;
    }

-    // Will do padding check on src data: Read 0 if src data is in padding area.
-    // Will do padding check on dst data: No write if dst data is in paddin area.
    template <typename SrcData, typename DstData>
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
-        constexpr auto vector_access_dim = Number<VectorAccessDim>{};
+        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

-        constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
-        constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
+        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
+        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

-        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
+        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};

        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);

-        ford<decltype(long_vector_access_lengths), DimAccessOrder>{}([&](
+        ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}([&](
            auto long_vector_access_id) {

            // data id w.r.t slicing-window
@@ -109,13 +107,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
            const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

-            // Check src vector's padding situation, only check the first data in this src
-            // vector. It's user's responsiblity to make sure all data in the src vector
-            // has the same padding situation
+            // Check src data's valid mapping situation, only check the first data in this src
+            // vector. It's user's responsiblity to make sure all data in the src vector
+            // has the valid/invalid mapping situation
            if(src_coord.IsUpperIndexMappedToValidOffset())
            {
                move_data<SrcData,
-                          SrcDataPerAccess,
+                          SrcDataPerRead,
                          SrcAddressSpace,
                          AddressSpace::vgpr,
                          InMemoryDataOperation::none>(
@@ -141,13 +139,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
            const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);

-            // Check dst vector's padding situation, only check the first data in this dst
-            // vector. It's user's responsiblity to make sure all data in the dst vector
-            // has the same padding situation
+            // Check dst data's valid mapping situation, only check the first data in this dst
+            // vector. It's user's responsiblity to make sure all data in the dst vector
+            // has the valid/invalid mapping situation
            if(dst_coord.IsUpperIndexMappedToValidOffset())
            {
                move_data<DstData,
-                          DstDataPerAccess,
+                          DstDataPerWrite,
                          AddressSpace::vgpr,
                          DstAddressSpace,
                          DstInMemOp>(
@@ -165,20 +163,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
        return Sequence<(Mask ? Lengths : 1)...>{};
    }

-    // Will do padding check on src data: Read 0 if src data is in padding area.
-    // Will do padding check on dst data: No write if dst data is in paddin area.
+    // Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
+    // Will do valid mapping check on dst data: No write if dst data has a invalid mapping
    // This version is optimized for address calculation of src tensor
    // TODO: this function is not compiled to expected ISA
    template <typename SrcData, typename DstData>
    __device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
                                                          DstData* p_dst) const
    {
-        constexpr auto vector_access_dim = Number<VectorAccessDim>{};
+        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

-        constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
-        constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
+        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
+        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

-        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
+        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};

        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
@@ -187,9 +185,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
        constexpr auto src_linear_dim_mask = SrcDesc::GetLinearDimensionMask();
        constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask();

-        static_assert(src_linear_dim_mask.At(VectorAccessDim) ||
-                          long_vector_size == SrcDataPerAccess,
-                      "Warning! VectorAccessDim is not SrcDesc's linear dimension, performance "
+        static_assert(
+            src_linear_dim_mask.At(SrcDstVectorReadWriteDim) || long_vector_size == SrcDataPerRead,
+            "Warning! SrcDstVectorReadWriteDim is not SrcDesc's linear dimension, performance "
            "would drop");

        // separate steps into linear and non-linear components, accoording to src tensor
@@ -230,13 +228,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                p_src_long_vector[i] = 0;
            }

-            // Loop over VectorAccessDim, and load data from src to the
+            // Loop over SrcDstVectorReadWriteDim, and load data from src to the
            // long-vector buffer.
-            // If VectorAccessDim is src's linear dimension, then src's
+            // If SrcDstVectorReadWriteDim is src's linear dimension, then src's
            // offset-diff due to this looping is known at compile-time. If
-            // VectorAccessDim is src's nonlinear dimension, then src's
+            // SrcDstVectorReadWriteDim is src's nonlinear dimension, then src's
            // offset-diff due to this looping is only known at run-time. For best
-            // performance, VectorAccessDim, should be src's linear dimension
+            // performance, SrcDstVectorReadWriteDim, should be src's linear dimension
            for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
            {
                auto scalar_id = make_zero_array<index_t, nDim>();
...@@ -258,13 +256,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -258,13 +256,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
src_coord.GetOffset() - src_nonlinear_coord.GetOffset(); src_coord.GetOffset() - src_nonlinear_coord.GetOffset();
#endif #endif
// Check src vector's padding situation, only check the first data in // Check src data's valid mapping situation, only check the first data in this
// this src vector. It's user's responsiblity to make sure all data in // src
// the src vector has the same padding situation // vector. It's user's responsiblity to make sure all data in the src vector
// has the valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<SrcData, move_data<SrcData,
SrcDataPerAccess, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::vgpr, AddressSpace::vgpr,
InMemoryDataOperation::none>(p_src, InMemoryDataOperation::none>(p_src,
...@@ -296,13 +295,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -296,13 +295,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps + const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps +
linear_dim_data_steps + scalar_id); linear_dim_data_steps + scalar_id);
// Check dst vector's padding situation, only check the first data in // Check dst data's valid mapping situation, only check the first data in this
// this dst vector. It's user's responsiblity to make sure all data in // dst
// the dst vector has the same padding situation // vector. It's user's responsiblity to make sure all data in the dst vector
// has the valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<DstData, move_data<DstData,
DstDataPerAccess, DstDataPerWrite,
AddressSpace::vgpr, AddressSpace::vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>( DstInMemOp>(
...@@ -313,20 +313,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -313,20 +313,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
}); });
} }
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in paddin area.
// This version is optimized for address calculation of dst tensor // This version is optimized for address calculation of dst tensor
// TODO: this function is not compiled to expected ISA // TODO: this function is not compiled to expected ISA
template <typename SrcData, typename DstData> template <typename SrcData, typename DstData>
__device__ void Run_optimized_dst_address_calculation(const SrcData* p_src, __device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
DstData* p_dst) const DstData* p_dst) const
{ {
constexpr auto vector_access_dim = Number<VectorAccessDim>{}; constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{}; constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{}; constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{}; constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify( constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
...@@ -335,9 +333,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -335,9 +333,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask(); constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask();
constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask(); constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask();
static_assert(dst_linear_dim_mask.At(VectorAccessDim) || static_assert(
long_vector_size == DstDataPerAccess, dst_linear_dim_mask.At(SrcDstVectorReadWriteDim) || long_vector_size == DstDataPerWrite,
"Warning! VectorAccessDim is not DstDesc's linear dimension, performance " "Warning! SrcDstVectorReadWriteDim is not DstDesc's linear dimension, performance "
"would drop"); "would drop");
// separate steps into linear and non-linear components, accoording to dst tensor // separate steps into linear and non-linear components, accoording to dst tensor
...@@ -378,13 +376,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -378,13 +376,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
p_src_long_vector[i] = 0; p_src_long_vector[i] = 0;
} }
// Loop over VectorAccessDim, and load data from src to the // Loop over SrcDstVectorReadWriteDim, and load data from src to the
// long-vector buffer. // long-vector buffer.
// If VectorAccessDim is dst's linear dimension, then dst's // If SrcDstVectorReadWriteDim is dst's linear dimension, then dst's
// offset-diff due to this looping is known at compile-time. If // offset-diff due to this looping is known at compile-time. If
// VectorAccessDim is dst's nonlinear dimension, then dst's // SrcDstVectorReadWriteDim is dst's nonlinear dimension, then dst's
// offset-diff due to this looping is only known at run-time. For best // offset-diff due to this looping is only known at run-time. For best
// performance, VectorAccessDim, should be dst's linear dimension // performance, SrcDstVectorReadWriteDim, should be dst's linear dimension
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i) for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
{ {
auto scalar_id = make_zero_array<index_t, nDim>(); auto scalar_id = make_zero_array<index_t, nDim>();
...@@ -397,13 +395,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -397,13 +395,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto src_coord = mSrcSliceOrigin + (nonlinear_dim_data_steps + const auto src_coord = mSrcSliceOrigin + (nonlinear_dim_data_steps +
linear_dim_data_steps + scalar_id); linear_dim_data_steps + scalar_id);
// Check src vector's padding situation, only check the first data in // Check src data's valid mapping situation, only check the first data in this
// this src vector. It's user's responsiblity to make sure all data in // src
// the src vector has the same padding situation // vector. It's user's responsiblity to make sure all data in the src vector
// has the valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<SrcData, move_data<SrcData,
SrcDataPerAccess, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::vgpr, AddressSpace::vgpr,
InMemoryDataOperation::none>( InMemoryDataOperation::none>(
...@@ -441,13 +440,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -441,13 +440,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset(); dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();
#endif #endif
// Check dst vector's padding situation, only check the first data in // Check dst data's valid mapping situation, only check the first data in this
// this dst vector. It's user's responsiblity to make sure all data in // dst
// the dst vector has the same padding situation // vector. It's user's responsiblity to make sure all data in the dst vector
// has the valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<DstData, move_data<DstData,
DstDataPerAccess, DstDataPerWrite,
AddressSpace::vgpr, AddressSpace::vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>(p_dst_long_vector, DstInMemOp>(p_dst_long_vector,
......
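The threadwise copy above works one long vector at a time: the register buffer is zero-filled, the loop walks SrcDstVectorReadWriteDim in chunks of src_data_per_access, and a chunk is loaded only when its first element maps to a valid (non-padded) offset; the dst side then repeats the same pattern with its own access width. Below is a minimal host-side sketch of that pattern, not the ck implementation; copy_long_vector and is_valid_chunk are hypothetical names used only for illustration.

#include <array>
#include <cstddef>

// Sketch: gather a long vector in chunks, checking only the first element of each
// chunk for a valid mapping (the caller guarantees the whole chunk agrees), then
// write the buffer out contiguously.
template <std::size_t LongVectorSize, std::size_t SrcDataPerAccess>
void copy_long_vector(const float* p_src,
                      float* p_dst,
                      bool (*is_valid_chunk)(std::size_t chunk_begin)) // hypothetical predicate
{
    static_assert(LongVectorSize % SrcDataPerAccess == 0, "chunks must tile the long vector");

    std::array<float, LongVectorSize> long_vector{}; // zero-initialized, like p_src_long_vector

    for(std::size_t i = 0; i < LongVectorSize / SrcDataPerAccess; ++i)
    {
        const std::size_t chunk_begin = i * SrcDataPerAccess;

        if(is_valid_chunk(chunk_begin)) // only the first data of the chunk is checked
        {
            for(std::size_t j = 0; j < SrcDataPerAccess; ++j)
                long_vector[chunk_begin + j] = p_src[chunk_begin + j];
        }
    }

    for(std::size_t i = 0; i < LongVectorSize; ++i)
        p_dst[i] = long_vector[i];
}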
...@@ -11,8 +11,8 @@ template <typename T, ...@@ -11,8 +11,8 @@ template <typename T,
typename OutDesc, typename OutDesc,
typename ConvStrides, typename ConvStrides,
typename ConvDilations, typename ConvDilations,
typename LeftPads, typename InLeftPads,
typename RightPads> typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw, Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc, WeiDesc wei_kcyx_desc,
...@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i ...@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
const Tensor<T>& out_nkhw, const Tensor<T>& out_nkhw,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
LeftPads, InLeftPads,
RightPads, InRightPads,
std::size_t nrepeat) std::size_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -62,24 +62,26 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i ...@@ -62,24 +62,26 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif #endif
constexpr index_t GemmM = C * Y * X; constexpr index_t GemmM = C * Y * X;
constexpr index_t GemmN = N * Ho * Wo; constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) * constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock); math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
...@@ -93,8 +95,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i ...@@ -93,8 +95,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
decltype(out_nkhw_desc), decltype(out_nkhw_desc),
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
LeftPads, InLeftPads,
RightPads, InRightPads,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i ...@@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN, GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyClusterLengths, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopyDataPerAccess, GemmABlockCopySrcDataPerRead_GemmM,
GemmBBlockCopySubLengths, GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyClusterLengths, GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyDataPerAccess, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmCThreadCopyDataPerAccess>{}; GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
......
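The old hand-written expression and the new math::integer_divide_ceil call compute the same ceiling division, so GridSize is simply the number of GemmMPerBlock x GemmNPerBlock tiles needed to cover the GemmM x GemmN problem. A standalone check with made-up sizes (C = 256, Y = X = 3, N = 64, Ho = Wo = 32 are arbitrary, not taken from the driver), assuming integer_divide_ceil(a, b) == (a + b - 1) / b:

#include <cstdio>

constexpr unsigned integer_divide_ceil(unsigned a, unsigned b) { return (a + b - 1) / b; }

int main()
{
    constexpr unsigned GemmM = 256 * 3 * 3;  // C * Y * X   = 2304
    constexpr unsigned GemmN = 64 * 32 * 32; // N * Ho * Wo = 65536
    constexpr unsigned GemmMPerBlock = 128;
    constexpr unsigned GemmNPerBlock = 128;

    // ceil(2304/128) * ceil(65536/128) = 18 * 512 = 9216 work-groups
    constexpr unsigned GridSize =
        integer_divide_ceil(GemmM, GemmMPerBlock) * integer_divide_ceil(GemmN, GemmNPerBlock);

    std::printf("GridSize = %u\n", GridSize);
    return 0;
}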
...@@ -11,8 +11,8 @@ template <typename T, ...@@ -11,8 +11,8 @@ template <typename T,
typename OutDesc, typename OutDesc,
typename ConvStrides, typename ConvStrides,
typename ConvDilations, typename ConvDilations,
typename LeftPads, typename InLeftPads,
typename RightPads> typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw, Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc, WeiDesc wei_kcyx_desc,
...@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
const Tensor<T>& out_nkhw, const Tensor<T>& out_nkhw,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
LeftPads, InLeftPads,
RightPads, InRightPads,
std::size_t nrepeat) std::size_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -68,54 +68,26 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -68,54 +68,26 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-M using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
#endif #endif
// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
// simplicity
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
...@@ -126,12 +98,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -126,12 +98,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Htilda = Ho + right_pad_ho; constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo; constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t GemmK = K * Ydot * Xdot;
constexpr index_t GemmM = C * Ytilda * Xtilda; constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda; constexpr index_t GemmN = N * Htilda * Wtilda;
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) * constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock); math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
...@@ -145,8 +116,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -145,8 +116,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
decltype(out_nkhw_desc), decltype(out_nkhw_desc),
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
LeftPads, InLeftPads,
RightPads, InRightPads,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -159,13 +130,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -159,13 +130,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN, GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyClusterLengths, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopyDataPerAccess, GemmABlockCopySrcDataPerRead_GemmM,
GemmBBlockCopySubLengths, GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyClusterLengths, GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyDataPerAccess, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmCThreadCopyDataPerAccess>{}; GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
......
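For v2r1 the GEMM shape comes from the stride/dilation decomposition: Ytilda and Xtilda are the strides divided by the highest common factor of stride and dilation, Ydot and Xdot are the filter lengths divided by Ytilda/Xtilda rounding up, and then GemmK = K * Ydot * Xdot, GemmM = C * Ytilda * Xtilda, GemmN = N * Htilda * Wtilda. A standalone numeric check of that derivation, assuming math::hcf is the greatest common divisor (std::gcd here) and math::integer_divide_ceil is (a + b - 1) / b; the 3x3 filter with 2x2 stride and 1x1 dilation is an arbitrary example, not one of the driver's configs:

#include <cstdio>
#include <numeric>

int main()
{
    constexpr unsigned Y = 3, X = 3;                         // filter
    constexpr unsigned ConvStrideH = 2, ConvStrideW = 2;     // stride
    constexpr unsigned ConvDilationH = 1, ConvDilationW = 1; // dilation

    constexpr unsigned hcf_h = std::gcd(ConvStrideH, ConvDilationH); // 1
    constexpr unsigned hcf_w = std::gcd(ConvStrideW, ConvDilationW); // 1

    constexpr unsigned Ytilda = ConvStrideH / hcf_h; // 2
    constexpr unsigned Xtilda = ConvStrideW / hcf_w; // 2

    constexpr unsigned Ydot = (Y + Ytilda - 1) / Ytilda; // ceil(3/2) = 2
    constexpr unsigned Xdot = (X + Xtilda - 1) / Xtilda; // ceil(3/2) = 2

    // GemmK = K * Ydot * Xdot and GemmM = C * Ytilda * Xtilda follow once K and C are known
    std::printf("Ytilda %u Xtilda %u Ydot %u Xdot %u\n", Ytilda, Xtilda, Ydot, Xdot);
    return 0;
}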
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "device.hpp" #include "device.hpp"
#include "tensor.hpp" #include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp" #include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp" #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
template <class T, template <class T,
class InDesc, class InDesc,
...@@ -11,8 +11,8 @@ template <class T, ...@@ -11,8 +11,8 @@ template <class T,
class OutDesc, class OutDesc,
class ConvStrides, class ConvStrides,
class ConvDilations, class ConvDilations,
class LeftPads, class InLeftPads,
class RightPads> class InRightPads>
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw, const Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
...@@ -21,8 +21,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -21,8 +21,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
Tensor<T>& out_nkhw, Tensor<T>& out_nkhw,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
LeftPads, InLeftPads,
RightPads, InRightPads,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -32,9 +32,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -32,9 +32,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{}; constexpr auto in_nchw_desc =
constexpr auto wei_kcyx_desc = WeiDesc{}; make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr auto out_nkhw_desc = OutDesc{}; constexpr auto wei_kcyx_desc =
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr auto out_nkhw_desc =
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
constexpr index_t N = out_nkhw_desc.GetLength(I0); constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1); constexpr index_t K = out_nkhw_desc.GetLength(I1);
...@@ -51,12 +54,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -51,12 +54,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1 #if 1
// BlockSize = 256, EPerBlock = 8 // BlockSize = 256, GemmKPerBlock = 8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t KPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t EPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -65,35 +68,30 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -65,35 +68,30 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
using InBlockCopySubLengths_E_B = Sequence<4, 1>; using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using InBlockCopyClusterLengths_E_B = Sequence<2, 128>; using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 1; constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0 #elif 0
// BlockSize = 256, EPerBlock = 8 // BlockSize = 256, GemmKPerBlock = 8
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t KPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t EPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -102,35 +100,30 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -102,35 +100,30 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
using InBlockCopySubLengths_E_B = Sequence<1, 4>; using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>; using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 4; constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t OutThreadCopyDataPerAccess_B = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0 #elif 0
// BlockSize = 256, EPerBlock = 16 // BlockSize = 256, GemmKPerBlock = 16
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t KPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t EPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -139,34 +132,29 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -139,34 +132,29 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
using InBlockCopySubLengths_E_B = Sequence<2, 4>; using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>; using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 4; constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t OutThreadCopyDataPerAccess_B = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 1 #elif 1
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t KPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t EPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -175,51 +163,47 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -175,51 +163,47 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
using InBlockCopySubLengths_E_B = Sequence<2, 2>; using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using InBlockCopyClusterLengths_E_B = Sequence<4, 64>; using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 2; constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t OutThreadCopyDataPerAccess_B = 2; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
#endif #endif
constexpr index_t B = N * Ho * Wo; constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
GridSize, GridSize,
BlockSize, BlockSize,
T, T,
T,
decltype(in_nchw_desc), decltype(in_nchw_desc),
decltype(wei_kcyx_desc), decltype(wei_kcyx_desc),
decltype(out_nkhw_desc), decltype(out_nkhw_desc),
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
LeftPads, InLeftPads,
RightPads, InRightPads,
BPerBlock, GemmMPerBlock,
KPerBlock, GemmNPerBlock,
EPerBlock, GemmKPerBlock,
GemmMPerThreadSubC, GemmMPerThreadSubC,
GemmNPerThreadSubC, GemmNPerThreadSubC,
GemmMLevel0Cluster, GemmMLevel0Cluster,
...@@ -227,22 +211,17 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -227,22 +211,17 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmDataPerReadA, ThreadGemmDataPerReadM,
GemmDataPerReadB, ThreadGemmDataPerReadN,
InBlockCopySubLengths_E_B, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
InBlockCopyClusterLengths_E_B, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
InBlockCopyThreadClusterArrangeOrder, GemmABlockCopySrcDataPerRead_GemmK,
InBlockCopySrcAccessOrder, GemmABlockCopyDstDataPerWrite_GemmM,
InBlockCopyDstAccessOrder, GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
InBlockCopyDataPerAccess_B, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
WeiBlockCopySubLengths_E_K, GemmBBlockCopySrcDataPerRead_GemmN,
WeiBlockCopyClusterLengths_E_K, GemmBBlockCopyDstDataPerWrite_GemmN,
WeiBlockCopyThreadClusterArrangeOrder, GemmCThreadCopyDstDataPerWrite_GemmN1>{};
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_B>{};
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
......
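The forward v4r4 path now names its GEMM view explicitly: GemmM = K, GemmN = N * Ho * Wo, and, following the old E = C * Y * X naming, the reduction dimension GemmK is presumably C * Y * X. The sketch below is a plain host-side reference of that index mapping, not ck code; conv_fwd_as_gemm is a hypothetical name and padding is omitted for brevity.

#include <vector>

// Reference: out = A^T * B with A = KCYX weight viewed as [GemmK, GemmM] and
// B = the implicitly im2col'ed NCHW input viewed as [GemmK, GemmN].
void conv_fwd_as_gemm(const std::vector<float>& in,  // NCHW,  N*C*Hi*Wi
                      const std::vector<float>& wei, // KCYX,  K*C*Y*X
                      std::vector<float>& out,       // NKHW,  N*K*Ho*Wo
                      int N, int C, int Hi, int Wi, int K, int Y, int X,
                      int stride_h, int stride_w, int dilation_h, int dilation_w)
{
    const int Ho = (Hi - dilation_h * (Y - 1) - 1) / stride_h + 1;
    const int Wo = (Wi - dilation_w * (X - 1) - 1) / stride_w + 1;

    for(int gemm_m = 0; gemm_m < K; ++gemm_m)               // GemmM = K
        for(int gemm_n = 0; gemm_n < N * Ho * Wo; ++gemm_n) // GemmN = N * Ho * Wo
        {
            const int n  = gemm_n / (Ho * Wo);
            const int ho = gemm_n % (Ho * Wo) / Wo;
            const int wo = gemm_n % Wo;

            float acc = 0;
            for(int gemm_k = 0; gemm_k < C * Y * X; ++gemm_k) // GemmK = C * Y * X (assumed)
            {
                const int c = gemm_k / (Y * X);
                const int y = gemm_k % (Y * X) / X;
                const int x = gemm_k % X;

                const int hi = ho * stride_h + y * dilation_h; // no padding in this sketch
                const int wi = wo * stride_w + x * dilation_w;

                acc += wei[((gemm_m * C + c) * Y + y) * X + x] *
                       in[((n * C + c) * Hi + hi) * Wi + wi];
            }
            out[((n * K + gemm_m) * Ho + ho) * Wo + wo] = acc;
        }
}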
...@@ -21,7 +21,7 @@ int main(int argc, char* argv[]) ...@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 0 #if 1
constexpr index_t N = 8; constexpr index_t N = 8;
constexpr index_t C = 128; constexpr index_t C = 128;
constexpr index_t HI = 16; constexpr index_t HI = 16;
......
...@@ -43,7 +43,7 @@ int main(int argc, char* argv[]) ...@@ -43,7 +43,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -250,7 +250,7 @@ int main(int argc, char* argv[]) ...@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128; constexpr index_t N = 128;
...@@ -296,7 +296,7 @@ int main(int argc, char* argv[]) ...@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>;
#elif 0 #elif 1
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -403,7 +403,7 @@ int main(int argc, char* argv[]) ...@@ -403,7 +403,7 @@ int main(int argc, char* argv[])
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......