Commit 3474c777 authored by Chao Liu's avatar Chao Liu
Browse files

add gemm padding to convnd

parent 7cc806d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
......@@ -20,10 +18,10 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert;
#if 0
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
#if 0
template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl<
NDimSpatial, //
......@@ -63,6 +61,11 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
#else
using CShuffleDataType = ck::half_t;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance =
ck::tensor_operation::device::DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle<
......@@ -76,7 +79,8 @@ using DeviceConvNDFwdInstance =
InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation
OutElementOp, // Output Elementwise Operation
ConvFwdDefault, // ConvForwardSpecialization
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
......
......@@ -15,6 +15,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
......@@ -118,6 +120,7 @@ template <index_t NDimSpatial,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -181,15 +184,25 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number;
static auto GetWeightTensorDescriptor(index_t GemmN, index_t GemmK)
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto GetWeightTensorDescriptor(index_t GemmNRaw, index_t GemmKRaw)
{
const index_t GemmK0 = GemmK / GemmK1Number;
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmN, GemmK));
make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw));
const auto wei_gemmn_gemmk_grid_desc =
matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
const auto GemmN = wei_gemmn_gemmk_grid_desc.GetLength(I0);
const auto GemmK = wei_gemmn_gemmk_grid_desc.GetLength(I1);
const index_t GemmK0 = GemmK / GemmK1Number;
// wei_gemmk0_gemmn_gemmk1_grid_desc
return transform_tensor_descriptor(
wei_k_yxc_grid_desc,
wei_gemmn_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
......@@ -198,25 +211,22 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static auto GetOutputTensorDescriptor(index_t GemmMRaw, index_t GemmN)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMPad = GemmM - GemmMRaw;
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmN));
// out_gemmm_gemmn_grid_desc
return transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmn_grid_desc);
return out_gemmm_gemmn_grid_desc;
}
template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmK,
index_t GemmKRaw,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
......@@ -225,10 +235,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMPad = GemmM - GemmMRaw;
const index_t GemmK0 = GemmK / GemmK1Number;
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
......@@ -237,45 +243,60 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmMRaw, GemmK));
make_naive_tensor_descriptor_packed(make_tuple(GemmMRaw, GemmKRaw));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
const auto GemmK0 = GemmK / GemmK1Number;
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmmraw_gemmk_grid_desc,
in_gemmm_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_right_pad_transform(GemmMRaw, GemmMPad)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_wi_e_grid_desc =
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_e_grid_desc = transform_tensor_descriptor(
in_n_wi_e_grid_desc,
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_wo_e_grid_desc,
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
const auto GemmK0 = GemmK / GemmK1Number;
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmm_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_merge_transform(make_tuple(N, Wo))),
make_tuple(Sequence<2>{}, Sequence<0, 1>{}),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return in_gemmk0_gemmm_gemmk1_grid_desc;
}
else
{
......@@ -284,19 +305,19 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_e_grid_desc =
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_e_grid_desc = transform_tensor_descriptor(
in_n_wi_e_grid_desc,
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_e_grid_desc = transform_tensor_descriptor(
in_n_wip_e_grid_desc,
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
......@@ -304,28 +325,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmk_gemmmraw_grid_desc =
transform_tensor_descriptor(in_n_x_wo_e_grid_desc,
make_tuple(make_merge_transform(make_tuple(X, C)),
make_merge_transform(make_tuple(N, Wo))),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
const auto GemmK0 = GemmK / GemmK1Number;
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmm_gemmk_grid_desc,
make_tuple(make_pass_through_transform(GemmM),
make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return in_gemmk0_gemmm_gemmk1_grid_desc;
}
}
......@@ -333,7 +355,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static auto GetInputTensorDescriptor(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmK,
index_t GemmKRaw,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
......@@ -342,12 +364,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMPad = GemmM - GemmMRaw;
const index_t GemmK0 = GemmK / GemmK1Number;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
......@@ -358,25 +376,33 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmK));
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmMRaw, GemmKRaw));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
const auto GemmK0 = GemmK / GemmK1Number;
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmmraw_gemmk_grid_desc,
in_gemmm_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, GemmMPad)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_hi_wi_e_grid_desc =
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_e_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_e_grid_desc,
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
......@@ -384,21 +410,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_ho_wo_e_grid_desc,
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
const auto GemmK0 = GemmK / GemmK1Number;
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmm_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return in_gemmk0_gemmm_gemmk1_grid_desc;
}
else
{
......@@ -414,11 +448,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_e_grid_desc =
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_e_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_e_grid_desc,
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
......@@ -426,8 +460,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_e_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_e_grid_desc,
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
......@@ -436,29 +470,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmmraw_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_e_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto GemmK0 = GemmK / GemmK1Number;
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmm_gemmk_grid_desc,
make_tuple(make_pass_through_transform(GemmM),
make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
return in_gemmk0_gemmm_gemmk1_grid_desc;
}
}
......@@ -466,7 +500,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static auto GetInputTensorDescriptor(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmK,
index_t GemmKRaw,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
......@@ -475,13 +509,9 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMPad = GemmM - GemmMRaw;
const index_t GemmK0 = GemmK / GemmK1Number;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
......@@ -494,25 +524,33 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmK));
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmMRaw, GemmKRaw));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
const auto GemmK0 = GemmK / GemmK1Number;
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmmraw_gemmk_grid_desc,
in_gemmm_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, GemmMPad)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_e_grid_desc =
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_e_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_e_grid_desc,
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
......@@ -523,22 +561,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_e_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
const auto GemmK0 = GemmK / GemmK1Number;
make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}),
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmm_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return in_gemmk0_gemmm_gemmk1_grid_desc;
}
else
{
......@@ -558,11 +603,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_e_grid_desc =
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_e_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_e_grid_desc,
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
......@@ -573,8 +618,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_e_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_e_grid_desc,
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
......@@ -589,28 +634,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_e_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
const auto GemmK0 = GemmK / GemmK1Number;
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmm_gemmk_grid_desc,
make_tuple(make_pass_through_transform(GemmM),
make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return in_gemmk0_gemmm_gemmk1_grid_desc;
}
}
......@@ -871,12 +917,14 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
{
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{" << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0)
<< ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment