Commit 19c3c9a8 authored by Chao Liu

updated bwd-data v1r1 and v2r1 to use gridwise gemm

parent 19a93dac
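Both kernels now route their inner loops through the shared transposed-A gridwise GEMM (GridwiseGemmTransposedANormalBNormalC_v1). As a rough orientation only (a host-side sketch with illustrative names and a flat row-major layout, not the kernel itself), the v1r1 mapping treats the weight as A with [GemmK = K, GemmM = C * Y * X], the output gradient as B with [GemmK = K, GemmN = N * Ho * Wo], and accumulates the input gradient as C before it is scattered back into NCHW:

#include <cstddef>
#include <vector>

// Hypothetical reference loop for the dgrad-as-GEMM view used below; the real
// kernels tile this GEMM across thread blocks and write C through a
// tensor-descriptor transform (with atomic adds in v1r1).
void bwd_data_as_gemm_reference(const std::vector<float>& wei,      // [GemmK][GemmM]
                                const std::vector<float>& out_grad, // [GemmK][GemmN]
                                std::vector<float>& in_grad,        // [GemmM][GemmN]
                                std::size_t GemmK,                  // = K
                                std::size_t GemmM,                  // = C * Y * X
                                std::size_t GemmN)                  // = N * Ho * Wo
{
    for(std::size_t m = 0; m < GemmM; ++m)
        for(std::size_t n = 0; n < GemmN; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < GemmK; ++k)
                acc += wei[k * GemmM + m] * out_grad[k * GemmN + n];
            in_grad[m * GemmN + n] += acc;
        }
}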
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
@@ -17,11 +17,11 @@ template <index_t GridSize,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t EPerBlock,
index_t BPerBlock,
index_t KPerBlock,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
@@ -31,13 +31,15 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename WeiBlockCopySubLengths_K_E,
typename WeiBlockCopyClusterLengths_K_E,
index_t WeiBlockCopyDataPerAccess_E,
typename OutBlockCopySubLengths_K_B,
typename OutBlockCopyClusterLengths_K_B,
index_t OutBlockCopyDataPerAccess_B,
index_t InThreadCopyDataPerAccess_B>
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmN,
index_t GemmABlockCopyDstDataPerWrite_GemmN,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
@@ -49,8 +51,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
@@ -73,14 +73,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
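// Implicit-GEMM dimension mapping for this kernel: GemmM = E = C * Y * X,
// GemmN = B = N * Ho * Wo, and the reduction dimension GemmK = K.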
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// output tensor
constexpr auto out_n_k_howo_global_desc =
@@ -99,8 +98,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
@@ -121,33 +121,40 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
// GEMM: atomic add
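// The GEMM output is indexed by (c, y, x, n, ho, wo); several (y, ho) and
// (x, wo) combinations can map to the same input pixel (hi, wi), so C-tensor
// writes into the input gradient may collide and are therefore combined with
// InMemoryDataOperation::atomic_add.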
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc),
InMemoryDataOperation::atomic_add,
EPerBlock,
BPerBlock,
KPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
WeiBlockCopySubLengths_K_E,
WeiBlockCopyClusterLengths_K_E,
WeiBlockCopyDataPerAccess_E,
OutBlockCopySubLengths_K_B,
OutBlockCopyClusterLengths_K_B,
OutBlockCopyDataPerAccess_B,
InThreadCopyDataPerAccess_B>{};
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc),
InMemoryDataOperation::atomic_add,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>, // ABlockCopyThreadClusterArrangeOrder
1,              // ABlockCopySrcVectorReadDim
GemmABlockCopySrcDataPerRead_GemmN,
GemmABlockCopyDstDataPerWrite_GemmN,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>, // BBlockCopyThreadClusterArrangeOrder
1,              // BBlockCopySrcVectorReadDim
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
3,              // CThreadCopyVectorReadWriteDim
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
@@ -34,14 +34,15 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopySubLengths, // Gemm-K, Gemm-M
typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M
index_t GemmABlockCopyDataPerAccess, // Gemm-M
typename GemmBBlockCopySubLengths, // Gemm-K, Gemm-N
typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N
index_t GemmBBlockCopyDataPerAccess, // Gemm-N
index_t GemmCThreadCopyDataPerAccess // Gemm-N
>
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
@@ -71,10 +72,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationW = ConvDilations{}[1];
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
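// Rough intent: factoring out hcf(stride, dilation) lets the filter window be
// split into Ydot x Ytilda (and Xdot x Xtilda) pieces whose GEMMs write
// disjoint input pixels, which is why this kernel can use
// InMemoryDataOperation::none below instead of the atomic add needed by v1r1.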
@@ -172,33 +175,40 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
// GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths,
GemmABlockCopyClusterLengths,
GemmABlockCopyDataPerAccess,
GemmBBlockCopySubLengths,
GemmBBlockCopyClusterLengths,
GemmBBlockCopyDataPerAccess,
GemmCThreadCopyDataPerAccess>{};
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>, // ABlockCopyThreadClusterArrangeOrder
1,              // ABlockCopySrcVectorReadDim
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>, // BBlockCopyThreadClusterArrangeOrder
1,              // BBlockCopySrcVectorReadDim
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
3,              // CThreadCopyVectorReadWriteDim
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
@@ -18,10 +18,10 @@ template <index_t BlockSize,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess,
index_t SrcVectorReadDim,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic,
@@ -146,8 +146,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
ThreadBufferDesc,
ThreadSliceLengths,
SrcDimAccessOrder,
SrcVectorAccessDim,
SrcDataPerAccess,
SrcVectorReadDim,
SrcDataPerRead,
1,
SrcAddressSpace,
ThreadBufferAddressSpace,
@@ -157,9 +157,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
BlockDstDesc,
ThreadSliceLengths,
DstDimAccessOrder,
DstVectorAccessDim,
DstVectorWriteDim,
1,
DstDataPerAccess,
DstDataPerWrite,
ThreadBufferAddressSpace,
DstAddressSpace,
DstInMemOp>;
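// Note that the blockwise copy is composed of the two threadwise copies above:
// SrcDataPerRead (formerly SrcDataPerAccess) vectorizes the src -> thread-buffer
// leg, while DstDataPerWrite (formerly DstDataPerAccess) vectorizes the
// thread-buffer -> dst leg.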
@@ -31,14 +31,21 @@ template <index_t GridSize,
index_t KPerThreadLoop,
index_t ThreadGemmDataPerReadM,
index_t ThreadGemmDataPerReadN,
typename ABlockCopySubLengths_K_M,
typename ABlockCopyClusterLengths_K_M,
index_t ABlockCopyDataPerAccess_M,
typename BBlockCopySubLengths_K_N,
typename BBlockCopyClusterLengths_K_N,
index_t BBlockCopyDataPerAccess_N,
index_t CThreadCopyDataPerAccess_N>
struct GridwiseGemmTransposedANormalBNormalC_v1r1
typename ABlockCopyThreadSliceLengths_K_M,
typename ABlockCopyThreadClusterLengths_K_M,
typename ABlockCopyThreadClusterArrangeOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_M,
typename BBlockCopyThreadSliceLengths_K_N,
typename BBlockCopyThreadClusterLengths_K_N,
typename BBlockCopyThreadClusterArrangeOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_N,
index_t CThreadCopyVectorReadWriteDim,
index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1
{
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
@@ -55,8 +62,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
BBlockCopyDataPerAccess_N,
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
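// LDS tiles are aligned to the widest vector access that touches them: the
// block-copy LDS writes (ABlockCopyDstDataPerWrite_M, BBlockCopyDstDataPerWrite_N)
// and the blockwise GEMM's LDS reads (ThreadGemmDataPerReadM/N), hence the lcm.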
@@ -86,15 +93,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopySubLengths_K_M,
ABlockCopyClusterLengths_K_M,
ABlockCopyThreadSliceLengths_K_M,
ABlockCopyThreadClusterLengths_K_M,
ABlockCopyThreadClusterArrangeOrder,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
ABlockCopySrcVectorReadDim,
1,
ABlockCopyDataPerAccess_M,
ABlockCopyDataPerAccess_M,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
@@ -112,15 +119,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopySubLengths_K_N,
BBlockCopyClusterLengths_K_N,
BBlockCopyThreadSliceLengths_K_N,
BBlockCopyThreadClusterLengths_K_N,
BBlockCopyThreadClusterArrangeOrder,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
BBlockCopySrcVectorReadDim,
1,
BBlockCopyDataPerAccess_N,
BBlockCopyDataPerAccess_N,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
@@ -305,9 +312,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
Sequence<0, 1, 2, 3>,
3,
CThreadCopyDataPerAccess_N,
CThreadCopyDataPerAccess_N,
CThreadCopyVectorReadWriteDim,
1,
CThreadCopyDstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::global,
CGlobalMemoryDataOperation>(
@@ -15,13 +15,15 @@ namespace ck {
// The dimension access order should be the same on src and dst.
// It is designed for cases, where one of src and dst is register, and
// the other is device memory or LDS
// Will do valid mapping check on src data: Read 0 if src data has an invalid mapping
// Will do valid mapping check on dst data: No write if dst data has an invalid mapping
template <typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t VectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess,
index_t VectorReadWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
@@ -45,10 +47,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid");
static_assert(
SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0,
SliceLengths{}[VectorReadWriteDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
"wrong! cannot evenly divide");
// TODO: sanity-check if vectorized memory access is allowed on src and dst
// TODO: sanity-check if vectorized memory read/write is allowed on src and dst
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
@@ -67,17 +69,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
mDstSliceOrigin = dst_slice_origin;
}
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area.
template <typename SrcData, typename DstData>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
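// Example: SrcDataPerRead = 4 and DstDataPerWrite = 1 give long_vector_size =
// lcm(4, 1) = 4, so each long-vector position is filled by one 4-wide read
// from src and drained by four 1-wide writes to dst.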
@@ -109,13 +109,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check src vector's padding situation, only check the first data in this src
// Check src data's valid mapping situation, only check the first data in this src
// vector. It's the user's responsibility to make sure all data in the src vector
// has the same padding situation
// has the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
{
move_data<SrcData,
SrcDataPerAccess,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
@@ -141,13 +141,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check dst vector's padding situation, only check the first data in this dst
// Check dst data's valid mapping situation, only check the first data in this dst
// vector. It's the user's responsibility to make sure all data in the dst vector
// has the same padding situation
// has the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
{
move_data<DstData,
DstDataPerAccess,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(
@@ -165,20 +165,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
return Sequence<(Mask ? Lengths : 1)...>{};
}
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area.
// Will do valid mapping check on src data: Read 0 if src data has an invalid mapping
// Will do valid mapping check on dst data: No write if dst data has an invalid mapping
// This version is optimized for address calculation of src tensor
// TODO: this function is not compiled to expected ISA
template <typename SrcData, typename DstData>
__device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
DstData* p_dst) const
{
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
@@ -187,9 +187,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr auto src_linear_dim_mask = SrcDesc::GetLinearDimensionMask();
constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask();
static_assert(src_linear_dim_mask.At(VectorAccessDim) ||
long_vector_size == SrcDataPerAccess,
"Warning! VectorAccessDim is not SrcDesc's linear dimension, performance "
static_assert(src_linear_dim_mask.At(VectorReadWriteDim) ||
long_vector_size == SrcDataPerRead,
"Warning! VectorReadWriteDim is not SrcDesc's linear dimension, performance "
"would drop");
// separate steps into linear and non-linear components, according to src tensor
@@ -230,13 +230,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
p_src_long_vector[i] = 0;
}
// Loop over VectorAccessDim, and load data from src to the
// Loop over VectorReadWriteDim, and load data from src to the
// long-vector buffer.
// If VectorAccessDim is src's linear dimension, then src's
// If VectorReadWriteDim is src's linear dimension, then src's
// offset-diff due to this looping is known at compile-time. If
// VectorAccessDim is src's nonlinear dimension, then src's
// VectorReadWriteDim is src's nonlinear dimension, then src's
// offset-diff due to this looping is only known at run-time. For best
// performance, VectorAccessDim should be src's linear dimension
// performance, VectorReadWriteDim should be src's linear dimension
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
@@ -258,13 +258,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
src_coord.GetOffset() - src_nonlinear_coord.GetOffset();
#endif
// Check src vector's padding situation, only check the first data in
// this src vector. It's the user's responsibility to make sure all data in
// the src vector has the same padding situation
// Check src data's valid mapping situation, only check the first data in this
// src vector. It's the user's responsibility to make sure all data in the src
// vector has the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
{
move_data<SrcData,
SrcDataPerAccess,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(p_src,
@@ -296,13 +297,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps +
linear_dim_data_steps + scalar_id);
// Check dst vector's padding situation, only check the first data in
// this dst vector. It's the user's responsibility to make sure all data in
// the dst vector has the same padding situation
// Check dst data's valid mapping situation, only check the first data in this
// dst vector. It's the user's responsibility to make sure all data in the dst
// vector has the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
{
move_data<DstData,
DstDataPerAccess,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(
@@ -313,20 +315,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
});
}
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area.
// This version is optimized for address calculation of dst tensor
// TODO: this function is not compiled to expected ISA
template <typename SrcData, typename DstData>
__device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
DstData* p_dst) const
{
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
@@ -335,9 +335,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask();
constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask();
static_assert(dst_linear_dim_mask.At(VectorAccessDim) ||
long_vector_size == DstDataPerAccess,
"Warning! VectorAccessDim is not DstDesc's linear dimension, performance "
static_assert(dst_linear_dim_mask.At(VectorReadWriteDim) ||
long_vector_size == DstDataPerWrite,
"Warning! VectorReadWriteDim is not DstDesc's linear dimension, performance "
"would drop");
// separate steps into linear and non-linear components, according to dst tensor
@@ -378,13 +378,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
p_src_long_vector[i] = 0;
}
// Loop over VectorAccessDim, and load data from src to the
// Loop over VectorReadWriteDim, and load data from src to the
// long-vector buffer.
// If VectorAccessDim is dst's linear dimension, then dst's
// If VectorReadWriteDim is dst's linear dimension, then dst's
// offset-diff due to this looping is known at compile-time. If
// VectorAccessDim is dst's nonlinear dimension, then dst's
// VectorReadWriteDim is dst's nonlinear dimension, then dst's
// offset-diff due to this looping is only known at run-time. For best
// performance, VectorAccessDim should be dst's linear dimension
// performance, VectorReadWriteDim should be dst's linear dimension
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
@@ -397,13 +397,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const auto src_coord = mSrcSliceOrigin + (nonlinear_dim_data_steps +
linear_dim_data_steps + scalar_id);
// Check src vector's padding situation, only check the first data in
// this src vector. It's the user's responsibility to make sure all data in
// the src vector has the same padding situation
// Check src data's valid mapping situation, only check the first data in this
// src vector. It's the user's responsibility to make sure all data in the src
// vector has the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
{
move_data<SrcData,
SrcDataPerAccess,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
@@ -441,13 +442,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();
#endif
// Check dst vector's padding situation, only check the first data in
// this dst vector. It's the user's responsibility to make sure all data in
// the dst vector has the same padding situation
// Check dst data's valid mapping situation, only check the first data in this
// dst vector. It's the user's responsibility to make sure all data in the dst
// vector has the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
{
move_data<DstData,
DstDataPerAccess,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(p_dst_long_vector,
@@ -62,17 +62,19 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; // Gemm-K, Gemm-M
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; // Gemm-M
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; // Gemm-M
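// Assuming the usual 256-thread block and an 8 x 128 (GemmK x GemmM) A tile
// configured above this hunk, the per-thread slice times the thread cluster
// must cover the tile: Sequence<1, 4> x Sequence<8, 32> = [8, 128], with
// 8 * 32 = 256 threads cooperating in the A-block copy.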
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; // Gemm-K, Gemm-N
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
constexpr index_t GemmM = C * Y * X;
@@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths,
GemmABlockCopyClusterLengths,
GemmABlockCopyDataPerAccess,
GemmBBlockCopySubLengths,
GemmBBlockCopyClusterLengths,
GemmBBlockCopyDataPerAccess,
GemmCThreadCopyDataPerAccess>{};
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
for(index_t i = 0; i < nrepeat; ++i)
{
@@ -68,45 +68,19 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
#elif 0
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
// TODO: this algo supports any stride and dilation. But for now, let's fix them to be 1 for
@@ -126,7 +100,6 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t GemmK = K * Ydot * Xdot;
constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda;
@@ -159,13 +132,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths,
GemmABlockCopyClusterLengths,
GemmABlockCopyDataPerAccess,
GemmBBlockCopySubLengths,
GemmBBlockCopyClusterLengths,
GemmBBlockCopyDataPerAccess,
GemmCThreadCopyDataPerAccess>{};
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
for(index_t i = 0; i < nrepeat; ++i)
{