Commit 9359ebfd authored by Chao Liu

updated fwd v4r4 to use gridwise gemm

parent 19c3c9a8
@@ -144,15 +144,18 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
 GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
 GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
 Sequence<0, 1>,
+Sequence<0, 1>,
 1,
 GemmABlockCopySrcDataPerRead_GemmN,
 GemmABlockCopyDstDataPerWrite_GemmN,
 GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
 GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
 Sequence<0, 1>,
+Sequence<0, 1>,
 1,
 GemmBBlockCopySrcDataPerRead_GemmN,
 GemmBBlockCopyDstDataPerWrite_GemmN,
+Sequence<0, 1, 2, 3>,
 3,
 GemmCThreadCopyDstDataPerWrite_GemmN1>{};
...
@@ -198,15 +198,18 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
 GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
 GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
 Sequence<0, 1>,
+Sequence<0, 1>,
 1,
 GemmABlockCopySrcDataPerRead_GemmM,
 GemmABlockCopyDstDataPerWrite_GemmM,
 GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
 GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
 Sequence<0, 1>,
+Sequence<0, 1>,
 1,
 GemmBBlockCopySrcDataPerRead_GemmN,
 GemmBBlockCopyDstDataPerWrite_GemmN,
+Sequence<0, 1, 2, 3>,
 3,
 GemmCThreadCopyDstDataPerWrite_GemmN1>{};
...
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
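//
// Illustrative sizes (an assumed example, not taken from this commit): with
// N = 8, C = 128, Hi = Wi = 16, K = 256, Y = X = 3, unit stride/dilation and no
// padding, Ho = Wo = 14, so the implicit GEMM problem would be roughly
//   GemmM = K           = 256
//   GemmN = N * Ho * Wo = 8 * 14 * 14 = 1568
//   GemmK = C * Y * X   = 128 * 3 * 3 = 1152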
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 &&
InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// weight tensor
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
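// Note: reading the Embed coefficient sequences above, the intended index map
// appears to be hip = ConvDilationH * y + ConvStrideH * ho (and likewise for
// x/wo in the W dimension), i.e. each output position gathers a Y x X window
// of the padded input.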
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_k_b_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_e_k_global_desc),
decltype(in_e_b_global_desc),
decltype(out_k_b_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
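The kernel above hands the whole convolution to GridwiseGemmTransposedANormalBNormalC_v1, i.e. it computes c = transpose(a) * b with a = wei_e_k [GemmK, GemmM], b = in_e_b [GemmK, GemmN] and c = out_k_b [GemmM, GemmN]. A minimal host-side sketch of that contraction, using hypothetical row-major arrays in place of the library's tensor descriptors:
// Reference contraction for the implicit-GEMM formulation (illustrative only;
// the flat row-major layouts are an assumption, not the library's descriptors).
#include <vector>

void reference_gemm_transposed_a(const std::vector<float>& a, // wei_e_k: [GemmK][GemmM]
                                 const std::vector<float>& b, // in_e_b:  [GemmK][GemmN]
                                 std::vector<float>& c,       // out_k_b: [GemmM][GemmN]
                                 int GemmM, int GemmN, int GemmK)
{
    for(int m = 0; m < GemmM; ++m)
        for(int n = 0; n < GemmN; ++n)
        {
            float acc = 0;
            for(int k = 0; k < GemmK; ++k)
                acc += a[k * GemmM + m] * b[k * GemmN + n]; // c = transpose(a) * b
            c[m * GemmN + n] = acc;
        }
}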
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_B,
typename InBlockCopyClusterLengths_E_B,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc =
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
constexpr auto out_n_k_ho_wo_global_desc =
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// LDS mem
// be careful of LDS alignment
constexpr auto in_e_b_block_desc =
make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
// input blockwise copy
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// global mem
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// LDS
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// weight blockwise copy
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, BPerBlock] is in LDS
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check
static_assert(
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
constexpr index_t GemmNRepeat =
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(c_k0k1_b0b1_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
if(has_two_iteration_left) // if there are 2 iterations left
{
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store last data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
}
}
// copy output: register to global memory
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
// src descriptor
constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
// dst descriptor
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t K0 = K / K1;
constexpr index_t B0 = B / B1;
constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
out_k_b_global_desc,
make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// output threadwise copy
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_b0_b1_thread_desc),
decltype(out_k0_k1_b0_b1_global_desc),
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 4, 1>::type,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B,
AddressSpace::vgpr,
AddressSpace::global>({0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1})
.Run(p_out_thread, p_out_global);
}
}
};
} // namespace ck
#endif
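The double-buffer kernel above preloads one EPerBlock slice into LDS, then alternates two LDS buffers so the global load of slice i+1 overlaps the blockwise GEMM on slice i, with a tail that handles the last one or two slices. A stripped-down sketch of that schedule (assumed control flow only; load and gemm stand in for the blockwise copy and blockwise GEMM above):
// Ping-pong schedule over E, illustrative only.
template <typename LoadFn, typename GemmFn>
void double_buffer_schedule(int E, int EPerBlock, LoadFn load, GemmFn gemm)
{
    int cur = 0;
    load(0, cur); // preload the first slice into buffer 0
    for(int e = EPerBlock; e < E; e += EPerBlock)
    {
        const int nxt = 1 - cur;
        load(e, nxt); // fetch the next slice while...
        gemm(cur);    // ...computing on the current one
        cur = nxt;    // swap buffers (the real kernel places __syncthreads() around the LDS store)
    }
    gemm(cur); // tail: the last slice has nothing left to prefetch
}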
@@ -9,6 +9,12 @@
 namespace ck {
+// This threadwise copy allow vector access of src and dst.
+// It allows the vector size to be different on src and dst.
+// The dimension of vector access can be different for src and dst.
+// The dimension access order can be different for src and dst.
+// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
+// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
 template <index_t BlockSize,
 typename BlockSrcDesc,
 typename BlockDstDesc,
...
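The added comment spells out the out-of-bounds policy of this copy: an invalid source mapping reads as zero, and an invalid destination mapping is simply not written. In scalar form the intent is roughly (hypothetical helper names, not the library API):
// Scalar picture of the valid-mapping rules, illustrative only.
const float v = src_is_valid ? p_src[src_offset] : 0.0f; // invalid src mapping reads 0
if(dst_is_valid)
    p_dst[dst_offset] = v;                                // invalid dst mapping: no write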
@@ -34,16 +34,19 @@ template <index_t GridSize,
 typename ABlockCopyThreadSliceLengths_K_M,
 typename ABlockCopyThreadClusterLengths_K_M,
 typename ABlockCopyThreadClusterArrangeOrder,
+typename ABlockCopySrcAccessOrder,
 index_t ABlockCopySrcVectorReadDim,
 index_t ABlockCopySrcDataPerRead,
 index_t ABlockCopyDstDataPerWrite_M,
 typename BBlockCopyThreadSliceLengths_K_N,
 typename BBlockCopyThreadClusterLengths_K_N,
 typename BBlockCopyThreadClusterArrangeOrder,
+typename BBlockCopySrcAccessOrder,
 index_t BBlockCopySrcVectorReadDim,
 index_t BBlockCopySrcDataPerRead,
 index_t BBlockCopyDstDataPerWrite_N,
-index_t CThreadCopyVectorReadWriteDim,
+typename CThreadCopySrcDstAccessOrder,
+index_t CThreadCopySrcDstVectorReadWriteDim,
 index_t CThreadCopyDstDataPerWrite>
 struct GridwiseGemmTransposedANormalBNormalC_v1
 {
@@ -96,7 +99,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
 ABlockCopyThreadSliceLengths_K_M,
 ABlockCopyThreadClusterLengths_K_M,
 ABlockCopyThreadClusterArrangeOrder,
-Sequence<0, 1>,
+ABlockCopySrcAccessOrder,
 Sequence<0, 1>,
 ABlockCopySrcVectorReadDim,
 1,
@@ -122,7 +125,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
 BBlockCopyThreadSliceLengths_K_N,
 BBlockCopyThreadClusterLengths_K_N,
 BBlockCopyThreadClusterArrangeOrder,
-Sequence<0, 1>,
+BBlockCopySrcAccessOrder,
 Sequence<0, 1>,
 BBlockCopySrcVectorReadDim,
 1,
@@ -311,8 +314,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
 ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
 decltype(c_m0_m1_n0_n1_global_desc),
 decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
-Sequence<0, 1, 2, 3>,
-CThreadCopyVectorReadWriteDim,
+CThreadCopySrcDstAccessOrder,
+CThreadCopySrcDstVectorReadWriteDim,
 1,
 CThreadCopyDstDataPerWrite,
 AddressSpace::vgpr,
...
@@ -8,20 +8,17 @@
 namespace ck {
-// This version use multi-index transformation
 // This threadwise copy allow vector access of src and dst.
 // It allows the vector size to be different on src and dst.
 // The dimensions of vector access should be the same on src and dst.
 // The dimension access order should be the same on src and dst.
-// It is designed for cases, where one of src and dst is register, and
-// the other is device memory or LDS
 // Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
 // Will do valid mapping check on dst data: No write if dst data has a invalid mapping
 template <typename SrcDesc,
 typename DstDesc,
 typename SliceLengths,
-typename DimAccessOrder,
-index_t VectorReadWriteDim,
+typename SrcDstDimAccessOrder,
+index_t SrcDstVectorReadWriteDim,
 index_t SrcDataPerRead,
 index_t DstDataPerWrite,
 AddressSpace SrcAddressSpace = AddressSpace::generic,
@@ -41,13 +38,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 {
 static_assert(nDim == SrcDesc::GetNumOfDimension() &&
 nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
-nDim == DimAccessOrder::Size(),
+nDim == SrcDstDimAccessOrder::Size(),
 "wrong! # of dimensions not the same");
-static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid");
+static_assert(is_valid_sequence_map<SrcDstDimAccessOrder>{}, "wrong! map is not valid");
-static_assert(
-SliceLengths{}[VectorReadWriteDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
+static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] %
+math::lcm(SrcDataPerRead, DstDataPerWrite) ==
+0,
 "wrong! cannot evenly divide");
 // TODO:: sanity-check if vectorized memory read/write is allowed on src and dst
@@ -72,7 +70,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 template <typename SrcData, typename DstData>
 __device__ void Run(const SrcData* p_src, DstData* p_dst) const
 {
-constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
+constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
 constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
 constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
@@ -82,7 +80,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 constexpr auto long_vector_access_lengths = SliceLengths::Modify(
 vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
-ford<decltype(long_vector_access_lengths), DimAccessOrder>{}([&](
+ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}([&](
 auto long_vector_access_id) {
 // data id w.r.t slicing-window
@@ -173,7 +171,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 __device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
 DstData* p_dst) const
 {
-constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
+constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
 constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
 constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
@@ -187,9 +185,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 constexpr auto src_linear_dim_mask = SrcDesc::GetLinearDimensionMask();
 constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask();
-static_assert(src_linear_dim_mask.At(VectorReadWriteDim) ||
-long_vector_size == SrcDataPerRead,
-"Warning! VectorReadWriteDim is not SrcDesc's linear dimension, performance "
+static_assert(
+src_linear_dim_mask.At(SrcDstVectorReadWriteDim) || long_vector_size == SrcDataPerRead,
+"Warning! SrcDstVectorReadWriteDim is not SrcDesc's linear dimension, performance "
 "would drop");
 // separate steps into linear and non-linear components, accoording to src tensor
@@ -230,13 +228,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 p_src_long_vector[i] = 0;
 }
-// Loop over VectorReadWriteDim, and load data from src to the
+// Loop over SrcDstVectorReadWriteDim, and load data from src to the
 // long-vector buffer.
-// If VectorReadWriteDim is src's linear dimension, then src's
+// If SrcDstVectorReadWriteDim is src's linear dimension, then src's
 // offset-diff due to this looping is known at compile-time. If
-// VectorReadWriteDim is src's nonlinear dimension, then src's
+// SrcDstVectorReadWriteDim is src's nonlinear dimension, then src's
 // offset-diff due to this looping is only known at run-time. For best
-// performance, VectorReadWriteDim, should be src's linear dimension
+// performance, SrcDstVectorReadWriteDim, should be src's linear dimension
 for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
 {
 auto scalar_id = make_zero_array<index_t, nDim>();
@@ -321,7 +319,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 __device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
 DstData* p_dst) const
 {
-constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
+constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
 constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
 constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
@@ -335,9 +333,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask();
 constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask();
-static_assert(dst_linear_dim_mask.At(VectorReadWriteDim) ||
-long_vector_size == DstDataPerWrite,
-"Warning! VectorReadWriteDim is not DstDesc's linear dimension, performance "
+static_assert(
+dst_linear_dim_mask.At(SrcDstVectorReadWriteDim) || long_vector_size == DstDataPerWrite,
+"Warning! SrcDstVectorReadWriteDim is not DstDesc's linear dimension, performance "
 "would drop");
 // separate steps into linear and non-linear components, accoording to dst tensor
@@ -378,13 +376,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
 p_src_long_vector[i] = 0;
 }
-// Loop over VectorReadWriteDim, and load data from src to the
+// Loop over SrcDstVectorReadWriteDim, and load data from src to the
 // long-vector buffer.
-// If VectorReadWriteDim is dst's linear dimension, then dst's
+// If SrcDstVectorReadWriteDim is dst's linear dimension, then dst's
 // offset-diff due to this looping is known at compile-time. If
-// VectorReadWriteDim is dst's nonlinear dimension, then dst's
+// SrcDstVectorReadWriteDim is dst's nonlinear dimension, then dst's
 // offset-diff due to this looping is only known at run-time. For best
-// performance, VectorReadWriteDim, should be dst's linear dimension
+// performance, SrcDstVectorReadWriteDim, should be dst's linear dimension
 for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
 {
 auto scalar_id = make_zero_array<index_t, nDim>();
...
@@ -83,13 +83,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif
-// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
-// simplicity
 constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
 constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
-constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong
-constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong
+constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
+constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
 constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
 constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
...
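With the stride/dilation restriction removed, Ytilda/Xtilda come from the stride divided by the highest common factor (gcd) of stride and dilation, and Ydot/Xdot from a ceiling division of the filter size. A quick worked example under assumed parameters (3x3 filter, stride 2, dilation 1):
// Illustrative numbers only.
constexpr int ConvStrideH = 2, ConvDilationH = 1, Y = 3;
constexpr int hcf_stride_dilation_h = 1;                    // gcd(2, 1)
constexpr int Ytilda = ConvStrideH / hcf_stride_dilation_h; // = 2
constexpr int Ydot   = (Y + Ytilda - 1) / Ytilda;           // ceil(3 / 2) = 2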
@@ -3,7 +3,7 @@
 #include "device.hpp"
 #include "tensor.hpp"
 #include "gridwise_convolution_kernel_wrapper.hpp"
-#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
+#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
 template <class T,
 class InDesc,
@@ -32,9 +32,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr auto I2 = Number<2>{};
 constexpr auto I3 = Number<3>{};
-constexpr auto in_nchw_desc = InDesc{};
-constexpr auto wei_kcyx_desc = WeiDesc{};
-constexpr auto out_nkhw_desc = OutDesc{};
+constexpr auto in_nchw_desc =
+make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
+constexpr auto wei_kcyx_desc =
+make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
+constexpr auto out_nkhw_desc =
+make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
 constexpr index_t N = out_nkhw_desc.GetLength(I0);
 constexpr index_t K = out_nkhw_desc.GetLength(I1);
@@ -51,12 +54,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
 #if 1
-// BlockSize = 256, EPerBlock = 8
+// BlockSize = 256, GemmKPerBlock = 8
 constexpr index_t BlockSize = 256;
-constexpr index_t BPerBlock = 128;
-constexpr index_t KPerBlock = 128;
-constexpr index_t EPerBlock = 8;
+constexpr index_t GemmMPerBlock = 128;
+constexpr index_t GemmNPerBlock = 128;
+constexpr index_t GemmKPerBlock = 8;
 constexpr index_t GemmMPerThreadSubC = 4;
 constexpr index_t GemmNPerThreadSubC = 4;
@@ -65,35 +68,30 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmMLevel1Cluster = 4;
 constexpr index_t GemmNLevel1Cluster = 4;
 constexpr index_t GemmKPerThreadLoop = 1;
-constexpr index_t GemmDataPerReadA = 4;
-constexpr index_t GemmDataPerReadB = 4;
+constexpr index_t ThreadGemmDataPerReadM = 4;
+constexpr index_t ThreadGemmDataPerReadN = 4;
-using InBlockCopySubLengths_E_B = Sequence<4, 1>;
-using InBlockCopyClusterLengths_E_B = Sequence<2, 128>;
-using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
-using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
-constexpr index_t InBlockCopyDataPerAccess_B = 1;
+using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
+using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
+constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
-using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
-using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
-using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
-constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
-constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-constexpr index_t OutThreadCopyDataPerAccess_B = 1;
+using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
+using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
+constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
+constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #elif 0
-// BlockSize = 256, EPerBlock = 8
+// BlockSize = 256, GemmKPerBlock = 8
 // 1x1 filter, 8x8 image
 constexpr index_t BlockSize = 256;
-constexpr index_t BPerBlock = 128;
-constexpr index_t KPerBlock = 128;
-constexpr index_t EPerBlock = 8;
+constexpr index_t GemmNPerBlock = 128;
+constexpr index_t GemmMPerBlock = 128;
+constexpr index_t GemmKPerBlock = 8;
 constexpr index_t GemmMPerThreadSubC = 4;
 constexpr index_t GemmNPerThreadSubC = 4;
@@ -102,35 +100,30 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmMLevel1Cluster = 4;
 constexpr index_t GemmNLevel1Cluster = 4;
 constexpr index_t GemmKPerThreadLoop = 1;
-constexpr index_t GemmDataPerReadA = 4;
-constexpr index_t GemmDataPerReadB = 4;
+constexpr index_t ThreadGemmDataPerReadM = 4;
+constexpr index_t ThreadGemmDataPerReadN = 4;
-using InBlockCopySubLengths_E_B = Sequence<1, 4>;
-using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
-using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
-using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
-constexpr index_t InBlockCopyDataPerAccess_B = 4;
+using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
+using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
+constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
-using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
-using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
-using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
-constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
-constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-constexpr index_t OutThreadCopyDataPerAccess_B = 4;
+using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
+using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
+constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
+constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 0
-// BlockSize = 256, EPerBlock = 16
+// BlockSize = 256, GemmKPerBlock = 16
 // 1x1 filter, 8x8 image
 constexpr index_t BlockSize = 256;
-constexpr index_t BPerBlock = 128;
-constexpr index_t KPerBlock = 128;
-constexpr index_t EPerBlock = 16;
+constexpr index_t GemmNPerBlock = 128;
+constexpr index_t GemmMPerBlock = 128;
+constexpr index_t GemmKPerBlock = 16;
 constexpr index_t GemmMPerThreadSubC = 4;
 constexpr index_t GemmNPerThreadSubC = 4;
@@ -139,34 +132,29 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmMLevel1Cluster = 4;
 constexpr index_t GemmNLevel1Cluster = 4;
 constexpr index_t GemmKPerThreadLoop = 1;
-constexpr index_t GemmDataPerReadA = 4;
-constexpr index_t GemmDataPerReadB = 4;
+constexpr index_t ThreadGemmDataPerReadM = 4;
+constexpr index_t ThreadGemmDataPerReadN = 4;
-using InBlockCopySubLengths_E_B = Sequence<2, 4>;
-using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
-using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
-using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
-constexpr index_t InBlockCopyDataPerAccess_B = 4;
+using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
+using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
+constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
+constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
-using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
-using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
-using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
-constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
-constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-constexpr index_t OutThreadCopyDataPerAccess_B = 4;
+using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
+using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
+constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
+constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 1
 // 1x1 filter, 14x14 image
 constexpr index_t BlockSize = 256;
-constexpr index_t BPerBlock = 128;
-constexpr index_t KPerBlock = 128;
-constexpr index_t EPerBlock = 8;
+constexpr index_t GemmNPerBlock = 128;
+constexpr index_t GemmMPerBlock = 128;
+constexpr index_t GemmKPerBlock = 8;
 constexpr index_t GemmMPerThreadSubC = 4;
 constexpr index_t GemmNPerThreadSubC = 4;
@@ -175,41 +163,36 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmMLevel1Cluster = 4;
 constexpr index_t GemmNLevel1Cluster = 4;
 constexpr index_t GemmKPerThreadLoop = 1;
-constexpr index_t GemmDataPerReadA = 4;
-constexpr index_t GemmDataPerReadB = 4;
+constexpr index_t ThreadGemmDataPerReadM = 4;
+constexpr index_t ThreadGemmDataPerReadN = 4;
-using InBlockCopySubLengths_E_B = Sequence<2, 2>;
-using InBlockCopyClusterLengths_E_B = Sequence<4, 64>;
-using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
-using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
-constexpr index_t InBlockCopyDataPerAccess_B = 2;
+using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
+using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
+constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
-using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
-using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
-using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
-constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
-constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-constexpr index_t OutThreadCopyDataPerAccess_B = 2;
+using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
+using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
+constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2;
+constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
+constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
 #endif
 constexpr index_t B = N * Ho * Wo;
 constexpr index_t GridSize =
-((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
+((B + GemmNPerBlock - 1) / GemmNPerBlock) * ((K + GemmMPerBlock - 1) / GemmMPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-constexpr auto gridwise_conv =
-GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
+constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
 GridSize,
 BlockSize,
 T,
+T,
 decltype(in_nchw_desc),
 decltype(wei_kcyx_desc),
 decltype(out_nkhw_desc),
@@ -217,9 +200,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 ConvDilations,
 LeftPads,
 RightPads,
-BPerBlock,
-KPerBlock,
-EPerBlock,
+GemmMPerBlock,
+GemmNPerBlock,
+GemmKPerBlock,
 GemmMPerThreadSubC,
 GemmNPerThreadSubC,
 GemmMLevel0Cluster,
@@ -227,22 +210,17 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 GemmMLevel1Cluster,
 GemmNLevel1Cluster,
 GemmKPerThreadLoop,
-GemmDataPerReadA,
-GemmDataPerReadB,
-InBlockCopySubLengths_E_B,
-InBlockCopyClusterLengths_E_B,
-InBlockCopyThreadClusterArrangeOrder,
-InBlockCopySrcAccessOrder,
-InBlockCopyDstAccessOrder,
-InBlockCopyDataPerAccess_B,
-WeiBlockCopySubLengths_E_K,
-WeiBlockCopyClusterLengths_E_K,
-WeiBlockCopyThreadClusterArrangeOrder,
-WeiBlockCopySrcAccessOrder,
-WeiBlockCopyDstAccessOrder,
-WeiBlockCopySrcDataPerRead_E,
-WeiBlockCopyDstDataPerWrite_K,
-OutThreadCopyDataPerAccess_B>{};
+ThreadGemmDataPerReadM,
+ThreadGemmDataPerReadN,
+GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+GemmABlockCopySrcDataPerRead_GemmK,
+GemmABlockCopyDstDataPerWrite_GemmM,
+GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+GemmBBlockCopySrcDataPerRead_GemmN,
+GemmBBlockCopyDstDataPerWrite_GemmN,
+GemmCThreadCopyDstDataPerWrite_GemmN1>{};
 for(index_t i = 0; i < nrepeat; ++i)
 {
...
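The grid size launches one workgroup per GemmMPerBlock x GemmNPerBlock tile of the output GEMM. Taking N = 8, K = 256, Ho = Wo = 14 as illustrative sizes together with the 128 x 128 tile selected above:
// Illustrative grid-size arithmetic for the renamed tile parameters.
constexpr int N = 8, K = 256, Ho = 14, Wo = 14;
constexpr int GemmMPerBlock = 128, GemmNPerBlock = 128;
constexpr int B        = N * Ho * Wo;                                // 1568 (the GemmN extent)
constexpr int GridSize = ((B + GemmNPerBlock - 1) / GemmNPerBlock) * // ceil(1568 / 128) = 13
                         ((K + GemmMPerBlock - 1) / GemmMPerBlock);  // * ceil(256 / 128) = 2 -> 26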
@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
 {
 using namespace ck;
-#if 0
+#if 1
 constexpr index_t N = 8;
 constexpr index_t C = 128;
 constexpr index_t HI = 16;
...
@@ -43,7 +43,7 @@ int main(int argc, char* argv[])
 using LeftPads = Sequence<0, 0>;
 using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
 // 3x3, 34x34
 constexpr index_t N = 64;
 constexpr index_t C = 256;
@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
 using LeftPads = Sequence<0, 0>;
 using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
 // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
 // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
 constexpr index_t N = 128;
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
 using LeftPads = Sequence<3, 0>;
 using RightPads = Sequence<3, 0>;
-#elif 0
+#elif 1
 // 1x7 filter, 0x3 pad, 17x17 input
 constexpr index_t N = 128;
 constexpr index_t C = 128;
@@ -403,7 +403,7 @@ int main(int argc, char* argv[])
 ConvStrides{},
 ConvDilations{},
 nrepeat);
-#elif 1
+#elif 0
 device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
 in_nchw,
 wei_kcyx_desc,
...