Commit 19c3c9a8 authored by Chao Liu

updated bwd-data v1r1 and v2r1 to use gridwise gemm

parent 19a93dac
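For orientation: both kernels now hand the backward-data problem to the same transposed-A gridwise GEMM, with A = the weight tensor viewed as GemmK x GemmM, B = the output gradient viewed as GemmK x GemmN, and C = the input gradient viewed as GemmM x GemmN. The host-side sketch below is illustrative only: the function name and the flat row-major layouts are assumptions, and the scatter from the (GemmM, GemmN) view back to overlapping NCHW input coordinates (which is what forces the atomic adds in v1r1) is omitted.

    // Hypothetical reference (not part of the commit): backward-data as the
    // transposed-A GEMM used below, with GemmK = K, GemmM = C*Y*X, GemmN = N*Ho*Wo.
    #include <cstddef>
    #include <vector>

    void bwd_data_as_gemm(std::size_t K, std::size_t C, std::size_t Y, std::size_t X,
                          std::size_t N, std::size_t Ho, std::size_t Wo,
                          const std::vector<float>& wei_k_e, // A: K x (C*Y*X)
                          const std::vector<float>& out_k_b, // B: K x (N*Ho*Wo)
                          std::vector<float>& in_e_b)        // C: (C*Y*X) x (N*Ho*Wo)
    {
        const std::size_t GemmK = K;
        const std::size_t GemmM = C * Y * X;
        const std::size_t GemmN = N * Ho * Wo;

        // C[m][n] += sum_k A[k][m] * B[k][n]; the "+=" mirrors the
        // InMemoryDataOperation::atomic_add accumulation used by v1r1.
        for(std::size_t k = 0; k < GemmK; ++k)
            for(std::size_t m = 0; m < GemmM; ++m)
                for(std::size_t n = 0; n < GemmN; ++n)
                    in_e_b[m * GemmN + n] += wei_k_e[k * GemmM + m] * out_k_b[k * GemmN + n];
    }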
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP #ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP #define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -17,11 +17,11 @@ template <index_t GridSize, ...@@ -17,11 +17,11 @@ template <index_t GridSize,
typename OutGlobalDesc, typename OutGlobalDesc,
typename ConvStrides, typename ConvStrides,
typename ConvDilations, typename ConvDilations,
typename LeftPads, typename InLeftPads,
typename RightPads, typename InRightPads,
index_t EPerBlock, index_t GemmMPerBlock,
index_t BPerBlock, index_t GemmNPerBlock,
index_t KPerBlock, index_t GemmKPerBlock,
index_t GemmMPerThreadSubC, index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC, index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster, index_t GemmMLevel0Cluster,
...@@ -31,13 +31,15 @@ template <index_t GridSize, ...@@ -31,13 +31,15 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM, index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN, index_t GemmThreadGemmDataPerReadN,
typename WeiBlockCopySubLengths_K_E, typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename WeiBlockCopyClusterLengths_K_E, typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t WeiBlockCopyDataPerAccess_E, index_t GemmABlockCopySrcDataPerRead_GemmN,
typename OutBlockCopySubLengths_K_B, index_t GemmABlockCopyDstDataPerWrite_GemmN,
typename OutBlockCopyClusterLengths_K_B, typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
index_t OutBlockCopyDataPerAccess_B, typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t InThreadCopyDataPerAccess_B> index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{ {
__device__ void Run(Float* __restrict__ p_in_global, __device__ void Run(Float* __restrict__ p_in_global,
...@@ -49,8 +51,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw ...@@ -49,8 +51,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
...@@ -73,14 +73,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw ...@@ -73,14 +73,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0]; constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1]; constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
// sanity-check for vectorized memory load // sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) && // TODO: this logic may not be correct for bwd-data
(X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0), static_assert(
"wrong! aligment requirement for vectorized global load of input tensor will " (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
"be violated"); (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// output tensor // output tensor
constexpr auto out_n_k_howo_global_desc = constexpr auto out_n_k_howo_global_desc =
...@@ -99,8 +98,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw ...@@ -99,8 +98,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
// input tensor // input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple( make_tuple(PassThrough<N>{},
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}), PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
...@@ -121,33 +121,40 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw ...@@ -121,33 +121,40 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
// GEMM: atomic add // GEMM: atomic add
constexpr auto gridwise_gemm = constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize, GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize, BlockSize,
Float, Float,
AccFloat, AccFloat,
decltype(wei_k_e_global_desc), decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc), decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
InMemoryDataOperation::atomic_add, InMemoryDataOperation::atomic_add,
EPerBlock, GemmMPerBlock,
BPerBlock, GemmNPerBlock,
KPerBlock, GemmKPerBlock,
GemmMPerThreadSubC, GemmMPerThreadSubC,
GemmNPerThreadSubC, GemmNPerThreadSubC,
GemmMLevel0Cluster, GemmMLevel0Cluster,
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN, GemmThreadGemmDataPerReadN,
WeiBlockCopySubLengths_K_E, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
WeiBlockCopyClusterLengths_K_E, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
WeiBlockCopyDataPerAccess_E, Sequence<0, 1>,
OutBlockCopySubLengths_K_B, 1,
OutBlockCopyClusterLengths_K_B, GemmABlockCopySrcDataPerRead_GemmN,
OutBlockCopyDataPerAccess_B, GemmABlockCopyDstDataPerWrite_GemmN,
InThreadCopyDataPerAccess_B>{}; GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
} }
......
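The sanity check carried over into both kernels (and flagged above with a TODO as possibly wrong for bwd-data) is a plain alignment rule on the W dimension. As a hedged restatement, it can be evaluated on the host with a helper like the following (the helper name is made up):

    // Hypothetical host-side restatement of the condition in the static_assert
    // above: along W, either there is nothing to vectorize over (Wo == 1), or
    // the stride is 1, or the vector width is 1; and the dilation must be a
    // multiple of the vector width unless X == 1.
    constexpr bool input_w_vector_access_ok(int Wo, int X, int ConvStrideW,
                                            int ConvDilationW, int DataPerAccess)
    {
        return (Wo == 1 || ConvStrideW == 1 || DataPerAccess == 1) &&
               (X == 1 || ConvDilationW % DataPerAccess == 0);
    }

    // e.g. with stride 1, dilation 1 and a vector width of 1 the condition
    // holds for any Wo and X: input_w_vector_access_ok(28, 3, 1, 1, 1) == true.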
...@@ -34,14 +34,15 @@ template <index_t GridSize, ...@@ -34,14 +34,15 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM, index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN, index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopySubLengths, // Gemm-K, Gemm-M typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopyDataPerAccess, // Gemm-M index_t GemmABlockCopySrcDataPerRead_GemmM,
typename GemmBBlockCopySubLengths, // Gemm-K, Gemm-N index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
index_t GemmBBlockCopyDataPerAccess, // Gemm-N typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmCThreadCopyDataPerAccess // Gemm-N index_t GemmBBlockCopySrcDataPerRead_GemmN,
> index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{ {
__device__ void Run(Float* __restrict__ p_in_global, __device__ void Run(Float* __restrict__ p_in_global,
...@@ -71,10 +72,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -71,10 +72,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationW = ConvDilations{}[1]; constexpr index_t ConvDilationW = ConvDilations{}[1];
// sanity-check for vectorized memory load // sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) && // TODO: this logic may not be correct for bwd-data
(X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0), static_assert(
"wrong! aligment requirement for vectorized global load of input tensor will " (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
"be violated"); (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
...@@ -172,33 +175,40 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -172,33 +175,40 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
// GEMM // GEMM
constexpr auto gridwise_gemm = constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize, GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize, BlockSize,
Float, Float,
AccFloat, AccFloat,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc), decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc), decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none, InMemoryDataOperation::none,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerThreadSubC, GemmMPerThreadSubC,
GemmNPerThreadSubC, GemmNPerThreadSubC,
GemmMLevel0Cluster, GemmMLevel0Cluster,
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN, GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyClusterLengths, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopyDataPerAccess, Sequence<0, 1>,
GemmBBlockCopySubLengths, 1,
GemmBBlockCopyClusterLengths, GemmABlockCopySrcDataPerRead_GemmM,
GemmBBlockCopyDataPerAccess, GemmABlockCopyDstDataPerWrite_GemmM,
GemmCThreadCopyDataPerAccess>{}; GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
} }
......
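v2r1 starts its decomposition from math::hcf of the convolution stride and dilation in each spatial direction. A minimal stand-in, assuming hcf has the usual highest-common-factor (gcd) semantics, is:

    // Assumed semantics of math::hcf: highest common factor (gcd) of two
    // positive integers, evaluated at compile time.
    constexpr int hcf(int a, int b) { return b == 0 ? a : hcf(b, a % b); }

    static_assert(hcf(2, 1) == 1, "stride 2, dilation 1 -> hcf 1");
    static_assert(hcf(4, 2) == 2, "stride 4, dilation 2 -> hcf 2");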
...@@ -18,10 +18,10 @@ template <index_t BlockSize, ...@@ -18,10 +18,10 @@ template <index_t BlockSize,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder, typename SrcDimAccessOrder,
typename DstDimAccessOrder, typename DstDimAccessOrder,
index_t SrcVectorAccessDim, index_t SrcVectorReadDim,
index_t DstVectorAccessDim, index_t DstVectorWriteDim,
index_t SrcDataPerAccess, index_t SrcDataPerRead,
index_t DstDataPerAccess, index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic, AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::generic, AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic, AddressSpace DstAddressSpace = AddressSpace::generic,
...@@ -146,8 +146,8 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -146,8 +146,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
ThreadBufferDesc, ThreadBufferDesc,
ThreadSliceLengths, ThreadSliceLengths,
SrcDimAccessOrder, SrcDimAccessOrder,
SrcVectorAccessDim, SrcVectorReadDim,
SrcDataPerAccess, SrcDataPerRead,
1, 1,
SrcAddressSpace, SrcAddressSpace,
ThreadBufferAddressSpace, ThreadBufferAddressSpace,
...@@ -157,9 +157,9 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -157,9 +157,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
BlockDstDesc, BlockDstDesc,
ThreadSliceLengths, ThreadSliceLengths,
DstDimAccessOrder, DstDimAccessOrder,
DstVectorAccessDim, DstVectorWriteDim,
1, 1,
DstDataPerAccess, DstDataPerWrite,
ThreadBufferAddressSpace, ThreadBufferAddressSpace,
DstAddressSpace, DstAddressSpace,
DstInMemOp>; DstInMemOp>;
......
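The renamed copy parameters spell out the division of labour: each thread owns a ThreadSliceLengths sub-slice, the threads form a cluster (the ...ThreadClusterLengths... parameters in the convolution kernels above), and slice times cluster must tile the block-level copy while the cluster accounts for every thread in the block. A hedged compile-time check of that invariant (the helper is hypothetical; the invariant itself is the usual one for these copies) could look like:

    // Hypothetical check: per-thread slice times thread cluster must equal the
    // block tile in every dimension, and the cluster must cover the block.
    template <int SliceK, int SliceM, int ClusterK, int ClusterM,
              int BlockTileK, int BlockTileM, int BlockSize>
    constexpr bool block_copy_shape_ok()
    {
        return SliceK * ClusterK == BlockTileK && SliceM * ClusterM == BlockTileM &&
               ClusterK * ClusterM == BlockSize;
    }

    // e.g. Sequence<1, 4> slices with Sequence<8, 32> clusters (the v1r1 A-copy
    // configuration in the driver further down) tile an 8 x 128 block tile with
    // 256 threads:
    static_assert(block_copy_shape_ok<1, 4, 8, 32, 8, 128, 256>(), "");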
...@@ -31,14 +31,21 @@ template <index_t GridSize, ...@@ -31,14 +31,21 @@ template <index_t GridSize,
index_t KPerThreadLoop, index_t KPerThreadLoop,
index_t ThreadGemmDataPerReadM, index_t ThreadGemmDataPerReadM,
index_t ThreadGemmDataPerReadN, index_t ThreadGemmDataPerReadN,
typename ABlockCopySubLengths_K_M, typename ABlockCopyThreadSliceLengths_K_M,
typename ABlockCopyClusterLengths_K_M, typename ABlockCopyThreadClusterLengths_K_M,
index_t ABlockCopyDataPerAccess_M, typename ABlockCopyThreadClusterArrangeOrder,
typename BBlockCopySubLengths_K_N, index_t ABlockCopySrcVectorReadDim,
typename BBlockCopyClusterLengths_K_N, index_t ABlockCopySrcDataPerRead,
index_t BBlockCopyDataPerAccess_N, index_t ABlockCopyDstDataPerWrite_M,
index_t CThreadCopyDataPerAccess_N> typename BBlockCopyThreadSliceLengths_K_N,
struct GridwiseGemmTransposedANormalBNormalC_v1r1 typename BBlockCopyThreadClusterLengths_K_N,
typename BBlockCopyThreadClusterArrangeOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_N,
index_t CThreadCopyVectorReadWriteDim,
index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1
{ {
__device__ void Run(const Float* __restrict__ p_a_global, __device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
...@@ -55,8 +62,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 ...@@ -55,8 +62,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
constexpr auto N = b_k_n_global_desc.GetLengths()[1]; constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// lds max alignment // lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M, constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDataPerAccess_N, BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM, ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN); ThreadGemmDataPerReadN);
...@@ -86,15 +93,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 ...@@ -86,15 +93,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
decltype(a_k_m_global_desc), decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc), decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()), decltype(a_k_m_block_desc.GetLengths()),
ABlockCopySubLengths_K_M, ABlockCopyThreadSliceLengths_K_M,
ABlockCopyClusterLengths_K_M, ABlockCopyThreadClusterLengths_K_M,
ABlockCopyThreadClusterArrangeOrder,
Sequence<0, 1>, Sequence<0, 1>,
Sequence<0, 1>, Sequence<0, 1>,
Sequence<0, 1>, ABlockCopySrcVectorReadDim,
1,
1, 1,
ABlockCopyDataPerAccess_M, ABlockCopySrcDataPerRead,
ABlockCopyDataPerAccess_M, ABlockCopyDstDataPerWrite_M,
AddressSpace::global, AddressSpace::global,
AddressSpace::vgpr, AddressSpace::vgpr,
AddressSpace::lds, AddressSpace::lds,
...@@ -112,15 +119,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 ...@@ -112,15 +119,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
decltype(b_k_n_global_desc), decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc), decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()), decltype(b_k_n_block_desc.GetLengths()),
BBlockCopySubLengths_K_N, BBlockCopyThreadSliceLengths_K_N,
BBlockCopyClusterLengths_K_N, BBlockCopyThreadClusterLengths_K_N,
BBlockCopyThreadClusterArrangeOrder,
Sequence<0, 1>, Sequence<0, 1>,
Sequence<0, 1>, Sequence<0, 1>,
Sequence<0, 1>, BBlockCopySrcVectorReadDim,
1,
1, 1,
BBlockCopyDataPerAccess_N, BBlockCopySrcDataPerRead,
BBlockCopyDataPerAccess_N, BBlockCopyDstDataPerWrite_N,
AddressSpace::global, AddressSpace::global,
AddressSpace::vgpr, AddressSpace::vgpr,
AddressSpace::lds, AddressSpace::lds,
...@@ -305,9 +312,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 ...@@ -305,9 +312,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
decltype(c_m0_m1_n0_n1_global_desc), decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()), decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, CThreadCopyVectorReadWriteDim,
CThreadCopyDataPerAccess_N, 1,
CThreadCopyDataPerAccess_N, CThreadCopyDstDataPerWrite,
AddressSpace::vgpr, AddressSpace::vgpr,
AddressSpace::global, AddressSpace::global,
CGlobalMemoryDataOperation>( CGlobalMemoryDataOperation>(
......
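max_lds_align above is the lcm of the two LDS write widths and the two thread-GEMM read widths. With the v1r1 driver values further down (write widths 4 and 1, read widths 4 and 4) that comes out to an alignment of 4 floats; a stand-in for math::lcm, assuming the usual semantics, makes this easy to check:

    // Assumed semantics of math::lcm, folded over four widths as in
    // max_lds_align above.
    constexpr int gcd2(int a, int b) { return b == 0 ? a : gcd2(b, a % b); }
    constexpr int lcm2(int a, int b) { return a / gcd2(a, b) * b; }
    constexpr int lcm4(int a, int b, int c, int d) { return lcm2(lcm2(a, b), lcm2(c, d)); }

    static_assert(lcm4(4, 1, 4, 4) == 4, "LDS rows padded to multiples of 4 floats");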
...@@ -15,13 +15,15 @@ namespace ck { ...@@ -15,13 +15,15 @@ namespace ck {
// The dimension access order should be the same on src and dst. // The dimension access order should be the same on src and dst.
// It is designed for cases where one of src and dst is a register, and // It is designed for cases where one of src and dst is a register, and
// the other is device memory or LDS // the other is device memory or LDS
// Will do valid mapping check on src data: Read 0 if src data has an invalid mapping
// Will do valid mapping check on dst data: No write if dst data has an invalid mapping
template <typename SrcDesc, template <typename SrcDesc,
typename DstDesc, typename DstDesc,
typename SliceLengths, typename SliceLengths,
typename DimAccessOrder, typename DimAccessOrder,
index_t VectorAccessDim, index_t VectorReadWriteDim,
index_t SrcDataPerAccess, index_t SrcDataPerRead,
index_t DstDataPerAccess, index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic, AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic, AddressSpace DstAddressSpace = AddressSpace::generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none> InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
...@@ -45,10 +47,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -45,10 +47,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid"); static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid");
static_assert( static_assert(
SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0, SliceLengths{}[VectorReadWriteDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
"wrong! cannot evenly divide"); "wrong! cannot evenly divide");
// TODO: sanity-check if vectorized memory access is allowed on src and dst // TODO: sanity-check if vectorized memory read/write is allowed on src and dst
} }
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2() __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
...@@ -67,17 +69,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -67,17 +69,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
mDstSliceOrigin = dst_slice_origin; mDstSliceOrigin = dst_slice_origin;
} }
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area.
template <typename SrcData, typename DstData> template <typename SrcData, typename DstData>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const __device__ void Run(const SrcData* p_src, DstData* p_dst) const
{ {
constexpr auto vector_access_dim = Number<VectorAccessDim>{}; constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{}; constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{}; constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{}; constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify( constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
...@@ -109,13 +109,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -109,13 +109,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id); const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check src vector's padding situation, only check the first data in this src // Check src data's valid mapping situation, only check the first data in this src
// vector. It's the user's responsibility to make sure all data in the src vector // vector. It's the user's responsibility to make sure all data in the src vector
// has the same padding situation // has the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<SrcData, move_data<SrcData,
SrcDataPerAccess, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::vgpr, AddressSpace::vgpr,
InMemoryDataOperation::none>( InMemoryDataOperation::none>(
...@@ -141,13 +141,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -141,13 +141,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id); const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check dst vector's padding situation, only check the first data in this dst // Check dst data's valid mapping situation, only check the first data in this dst
// vector. It's the user's responsibility to make sure all data in the dst vector // vector. It's the user's responsibility to make sure all data in the dst vector
// has the same padding situation // has the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<DstData, move_data<DstData,
DstDataPerAccess, DstDataPerWrite,
AddressSpace::vgpr, AddressSpace::vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>( DstInMemOp>(
...@@ -165,20 +165,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -165,20 +165,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
return Sequence<(Mask ? Lengths : 1)...>{}; return Sequence<(Mask ? Lengths : 1)...>{};
} }
// Will do padding check on src data: Read 0 if src data is in padding area. // Will do valid mapping check on src data: Read 0 if src data has an invalid mapping
// Will do padding check on dst data: No write if dst data is in padding area. // Will do valid mapping check on dst data: No write if dst data has an invalid mapping
// This version is optimized for address calculation of src tensor // This version is optimized for address calculation of src tensor
// TODO: this function does not compile to the expected ISA // TODO: this function does not compile to the expected ISA
template <typename SrcData, typename DstData> template <typename SrcData, typename DstData>
__device__ void Run_optimized_src_address_calculation(const SrcData* p_src, __device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
DstData* p_dst) const DstData* p_dst) const
{ {
constexpr auto vector_access_dim = Number<VectorAccessDim>{}; constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{}; constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{}; constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{}; constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify( constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
...@@ -187,9 +187,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -187,9 +187,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr auto src_linear_dim_mask = SrcDesc::GetLinearDimensionMask(); constexpr auto src_linear_dim_mask = SrcDesc::GetLinearDimensionMask();
constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask(); constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask();
static_assert(src_linear_dim_mask.At(VectorAccessDim) || static_assert(src_linear_dim_mask.At(VectorReadWriteDim) ||
long_vector_size == SrcDataPerAccess, long_vector_size == SrcDataPerRead,
"Warning! VectorAccessDim is not SrcDesc's linear dimension, performance " "Warning! VectorReadWriteDim is not SrcDesc's linear dimension, performance "
"would drop"); "would drop");
// separate steps into linear and non-linear components, according to src tensor // separate steps into linear and non-linear components, according to src tensor
...@@ -230,13 +230,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -230,13 +230,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
p_src_long_vector[i] = 0; p_src_long_vector[i] = 0;
} }
// Loop over VectorAccessDim, and load data from src to the // Loop over VectorReadWriteDim, and load data from src to the
// long-vector buffer. // long-vector buffer.
// If VectorAccessDim is src's linear dimension, then src's // If VectorReadWriteDim is src's linear dimension, then src's
// offset-diff due to this looping is known at compile-time. If // offset-diff due to this looping is known at compile-time. If
// VectorAccessDim is src's nonlinear dimension, then src's // VectorReadWriteDim is src's nonlinear dimension, then src's
// offset-diff due to this looping is only known at run-time. For best // offset-diff due to this looping is only known at run-time. For best
// performance, VectorAccessDim should be src's linear dimension // performance, VectorReadWriteDim should be src's linear dimension
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i) for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
{ {
auto scalar_id = make_zero_array<index_t, nDim>(); auto scalar_id = make_zero_array<index_t, nDim>();
...@@ -258,13 +258,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -258,13 +258,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
src_coord.GetOffset() - src_nonlinear_coord.GetOffset(); src_coord.GetOffset() - src_nonlinear_coord.GetOffset();
#endif #endif
// Check src vector's padding situation, only check the first data in // Check src data's valid mapping situation, only check the first data in this
// this src vector. It's the user's responsibility to make sure all data in // src
// the src vector has the same padding situation // vector. It's the user's responsibility to make sure all data in the src vector
// has the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<SrcData, move_data<SrcData,
SrcDataPerAccess, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::vgpr, AddressSpace::vgpr,
InMemoryDataOperation::none>(p_src, InMemoryDataOperation::none>(p_src,
...@@ -296,13 +297,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -296,13 +297,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps + const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps +
linear_dim_data_steps + scalar_id); linear_dim_data_steps + scalar_id);
// Check dst vector's padding situation, only check the first data in // Check dst data's valid mapping situation, only check the first data in this
// this dst vector. It's the user's responsibility to make sure all data in // dst
// the dst vector has the same padding situation // vector. It's the user's responsibility to make sure all data in the dst vector
// has the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<DstData, move_data<DstData,
DstDataPerAccess, DstDataPerWrite,
AddressSpace::vgpr, AddressSpace::vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>( DstInMemOp>(
...@@ -313,20 +315,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -313,20 +315,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
}); });
} }
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area.
// This version is optimized for address calculation of dst tensor // This version is optimized for address calculation of dst tensor
// TODO: this function does not compile to the expected ISA // TODO: this function does not compile to the expected ISA
template <typename SrcData, typename DstData> template <typename SrcData, typename DstData>
__device__ void Run_optimized_dst_address_calculation(const SrcData* p_src, __device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
DstData* p_dst) const DstData* p_dst) const
{ {
constexpr auto vector_access_dim = Number<VectorAccessDim>{}; constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{}; constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{}; constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{}; constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify( constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
...@@ -335,9 +335,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -335,9 +335,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask(); constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask();
constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask(); constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask();
static_assert(dst_linear_dim_mask.At(VectorAccessDim) || static_assert(dst_linear_dim_mask.At(VectorReadWriteDim) ||
long_vector_size == DstDataPerAccess, long_vector_size == DstDataPerWrite,
"Warning! VectorAccessDim is not DstDesc's linear dimension, performance " "Warning! VectorReadWriteDim is not DstDesc's linear dimension, performance "
"would drop"); "would drop");
// separate steps into linear and non-linear components, according to dst tensor // separate steps into linear and non-linear components, according to dst tensor
...@@ -378,13 +378,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -378,13 +378,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
p_src_long_vector[i] = 0; p_src_long_vector[i] = 0;
} }
// Loop over VectorAccessDim, and load data from src to the // Loop over VectorReadWriteDim, and load data from src to the
// long-vector buffer. // long-vector buffer.
// If VectorAccessDim is dst's linear dimension, then dst's // If VectorReadWriteDim is dst's linear dimension, then dst's
// offset-diff due to this looping is known at compile-time. If // offset-diff due to this looping is known at compile-time. If
// VectorAccessDim is dst's nonlinear dimension, then dst's // VectorReadWriteDim is dst's nonlinear dimension, then dst's
// offset-diff due to this looping is only known at run-time. For best // offset-diff due to this looping is only known at run-time. For best
// performance, VectorAccessDim should be dst's linear dimension // performance, VectorReadWriteDim should be dst's linear dimension
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i) for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
{ {
auto scalar_id = make_zero_array<index_t, nDim>(); auto scalar_id = make_zero_array<index_t, nDim>();
...@@ -397,13 +397,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -397,13 +397,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto src_coord = mSrcSliceOrigin + (nonlinear_dim_data_steps + const auto src_coord = mSrcSliceOrigin + (nonlinear_dim_data_steps +
linear_dim_data_steps + scalar_id); linear_dim_data_steps + scalar_id);
// Check src vector's padding situation, only check the first data in // Check src data's valid mapping situation, only check the first data in this
// this src vector. It's the user's responsibility to make sure all data in // src
// the src vector has the same padding situation // vector. It's the user's responsibility to make sure all data in the src vector
// has the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<SrcData, move_data<SrcData,
SrcDataPerAccess, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::vgpr, AddressSpace::vgpr,
InMemoryDataOperation::none>( InMemoryDataOperation::none>(
...@@ -441,13 +442,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -441,13 +442,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset(); dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();
#endif #endif
// Check dst vector's padding situation, only check the first data in // Check dst data's valid mapping situation, only check the first data in this
// this dst vector. It's the user's responsibility to make sure all data in // dst
// the dst vector has the same padding situation // vector. It's the user's responsibility to make sure all data in the dst vector
// has the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsUpperIndexMappedToValidOffset())
{ {
move_data<DstData, move_data<DstData,
DstDataPerAccess, DstDataPerWrite,
AddressSpace::vgpr, AddressSpace::vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>(p_dst_long_vector, DstInMemOp>(p_dst_long_vector,
......
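The behaviour the new comments describe (read 0 when a source coordinate has no valid offset, no write when a destination coordinate has none) boils down to the following in one dimension. This is a deliberately simplified sketch: the real Run() is N-dimensional, vectorized and coordinate-based, and the function name here is made up.

    #include <cstddef>

    // 1-D caricature of the valid-mapping check: out-of-range source elements
    // are read as 0, out-of-range destination elements are simply not written.
    void copy_with_valid_mapping(const float* p_src, std::size_t src_valid_len,
                                 float* p_dst, std::size_t dst_valid_len,
                                 std::size_t slice_len)
    {
        for(std::size_t i = 0; i < slice_len; ++i)
        {
            const float v = (i < src_valid_len) ? p_src[i] : 0.0f; // invalid src mapping -> 0
            if(i < dst_valid_len)                                  // invalid dst mapping -> no write
                p_dst[i] = v;
        }
    }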
...@@ -62,17 +62,19 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i ...@@ -62,17 +62,19 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; // Gemm-K, Gemm-M
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; // Gemm-M
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; // Gemm-M
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; // Gemm-K, Gemm-N
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif #endif
constexpr index_t GemmM = C * Y * X; constexpr index_t GemmM = C * Y * X;
...@@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i ...@@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN, GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyClusterLengths, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopyDataPerAccess, GemmABlockCopySrcDataPerRead_GemmM,
GemmBBlockCopySubLengths, GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyClusterLengths, GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyDataPerAccess, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmCThreadCopyDataPerAccess>{}; GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
......
...@@ -68,45 +68,19 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -68,45 +68,19 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-M using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
#endif #endif
// TODO: this algo supports any stride and dilation. But for now, let's fix them to be 1 for // TODO: this algo supports any stride and dilation. But for now, let's fix them to be 1 for
...@@ -126,7 +100,6 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -126,7 +100,6 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Htilda = Ho + right_pad_ho; constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo; constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t GemmK = K * Ydot * Xdot;
constexpr index_t GemmM = C * Ytilda * Xtilda; constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda; constexpr index_t GemmN = N * Htilda * Wtilda;
...@@ -159,13 +132,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -159,13 +132,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN, GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyClusterLengths, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopyDataPerAccess, GemmABlockCopySrcDataPerRead_GemmM,
GemmBBlockCopySubLengths, GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyClusterLengths, GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyDataPerAccess, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmCThreadCopyDataPerAccess>{}; GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
......
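Both drivers launch the kernel nrepeat times; a typical way such a loop is timed with HIP events is sketched below. This wrapper is hypothetical and not the project's timing code; launch_kernel stands in for whatever launches the gridwise operation.

    #include <hip/hip_runtime.h>

    // Hypothetical helper: average milliseconds per launch over nrepeat runs.
    template <typename Launch>
    float time_nrepeat(Launch launch_kernel, int nrepeat)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, nullptr);
        for(int i = 0; i < nrepeat; ++i)
            launch_kernel(); // one gridwise kernel launch per iteration
        hipEventRecord(stop, nullptr);
        hipEventSynchronize(stop);

        float total_ms = 0.0f;
        hipEventElapsedTime(&total_ms, start, stop);

        hipEventDestroy(start);
        hipEventDestroy(stop);
        return total_ms / nrepeat;
    }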