Commit 7cc806d8 authored by Chao Liu's avatar Chao Liu
Browse files

add matrix padder

parent 0b997ce4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
......
......@@ -181,51 +181,54 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number;
static auto GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k)
static auto GetWeightTensorDescriptor(index_t GemmN, index_t GemmK)
{
const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
const auto wei_k_yxe_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
const index_t GemmK0 = GemmK / GemmK1Number;
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmN, GemmK));
// wei_gemmk0_gemmn_gemmk1_grid_desc
return transform_tensor_descriptor(
wei_k_yxe_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_pass_through_transform(gemm_n)),
wei_k_yxc_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto
GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n, ck::index_t gemm_m_pad)
static auto GetOutputTensorDescriptor(index_t GemmMRaw, index_t GemmN)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMPad = GemmM - GemmMRaw;
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmN));
// out_gemmm_gemmn_grid_desc
return transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(gemm_m, gemm_m_pad),
make_pass_through_transform(gemm_n)),
make_tuple(make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmK,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMPad = GemmM - GemmMRaw;
const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
const index_t GemmK0 = GemmK / GemmK1Number;
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
......@@ -234,13 +237,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(GemmMRaw, GemmK));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmmraw_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_right_pad_transform(gemm_m, gemm_m_pad)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_right_pad_transform(GemmMRaw, GemmMPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
......@@ -260,7 +263,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_wo_e_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_merge_transform(make_tuple(N, Wo))),
make_tuple(Sequence<2>{}, Sequence<0, 1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -268,8 +271,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad),
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
......@@ -310,37 +313,39 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_pass_through_transform(gemm_m)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad),
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
template <index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmK,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{
const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMPad = GemmM - GemmMRaw;
const index_t GemmK0 = GemmK / GemmK1Number;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
......@@ -354,13 +359,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmK));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmmraw_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_right_pad_transform(gemm_m, gemm_m_pad)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, GemmMPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
......@@ -381,7 +386,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_ho_wo_e_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -389,8 +394,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad),
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
......@@ -440,8 +445,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_pass_through_transform(gemm_m)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -449,29 +454,31 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad),
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
template <index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmK,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{
const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMPad = GemmM - GemmMRaw;
const index_t GemmK0 = GemmK / GemmK1Number;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
......@@ -488,13 +495,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmK));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmmraw_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_right_pad_transform(gemm_m, gemm_m_pad)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, GemmMPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
......@@ -518,7 +525,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_e_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}),
......@@ -527,8 +534,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad),
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
......@@ -591,68 +598,65 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_pass_through_transform(gemm_m)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad),
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
}
static index_t GetGemmMRaw(ck::index_t N,
const std::vector<ck::index_t>& output_spatial_lengths)
static index_t GetGemmMRaw(index_t N, const std::vector<index_t>& output_spatial_lengths)
{
return N * std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
1,
std::multiplies<ck::index_t>());
std::multiplies<index_t>());
}
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
static index_t GetGemmKRaw(index_t C, const std::vector<index_t>& filter_spatial_lengths)
{
return C * std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
1,
std::multiplies<ck::index_t>());
std::multiplies<index_t>());
}
static auto
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads)
{
using namespace ck;
const index_t GemmMRaw = GetGemmMRaw(N, output_spatial_lengths);
const index_t GemmN = K;
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
const index_t GemmNRaw = K;
const index_t GemmKRaw = GetGemmKRaw(C, filter_spatial_lengths);
const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw;
assert(GemmK % GemmK1Number == 0);
// TODO: remove
assert(GemmKRaw % GemmK1Number == 0);
// A:
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
GetInputTensorDescriptor<NDimSpatial>(N,
C,
GemmMRaw,
GemmK,
GemmMPad,
GemmKRaw,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
......@@ -660,31 +664,34 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK);
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
GetWeightTensorDescriptor(GemmNRaw, GemmKRaw);
// E:
const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmN, GemmMPad);
const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmNRaw);
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABEGridDesc()
{
return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
template <index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABEGridDesc()
{
return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
template <index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABEGridDesc()
{
return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
......@@ -756,16 +763,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
Argument(const ADataType* p_in_grid,
const BDataType* p_wei_grid,
EDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
AElementwiseOperation in_element_op,
BElementwiseOperation wei_element_op,
CDEElementwiseOperation out_element_op)
......@@ -988,7 +995,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(ck::index_t i = 0; i < NDimSpatial; ++i)
for(index_t i = 0; i < NDimSpatial; ++i)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
......@@ -1001,7 +1008,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(ck::index_t i = 0; i < NDimSpatial; ++i)
for(index_t i = 0; i < NDimSpatial; ++i)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 &&
arg.input_right_pads_[i] == 0))
......@@ -1040,16 +1047,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static auto MakeArgument(const ADataType* p_in_grid,
const BDataType* p_wei_grid,
EDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
AElementwiseOperation in_element_op,
BElementwiseOperation wei_element_op,
CDEElementwiseOperation out_element_op)
......@@ -1078,16 +1085,16 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
MakeArgumentPointer(const void* p_in_grid,
const void* p_wei_grid,
void* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
AElementwiseOperation in_element_op,
BElementwiseOperation wei_element_op,
CDEElementwiseOperation out_element_op) override
......
......@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
......@@ -160,6 +161,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
......@@ -175,27 +179,15 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
......@@ -205,63 +197,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
......@@ -278,27 +213,15 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
......@@ -308,67 +231,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
......@@ -381,47 +247,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto e_grid_desc_m_n = matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
return e_grid_desc_m_n;
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// M/N/KPerTileType could be index_t or Number<>
template <GemmSpecialization GemmSpec,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct MatrixPadder
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
return transform_tensor_descriptor(a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or K
return a_desc_mraw_kraw;
}
}
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
return transform_tensor_descriptor(b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad N or K
return b_desc_nraw_kraw;
}
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_desc_mraw_nraw;
}
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment