Commit 90acba1d authored by Chao Liu
Browse files

refactor

parent a22c7cf5
...@@ -44,7 +44,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc ...@@ -44,7 +44,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
4, // NXdlPerWave 4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, 1 // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1 8, // ABlockTransferDstScalarPerVector_K1
...@@ -69,10 +69,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio ...@@ -69,10 +69,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle< using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
ck::tensor_layout::convolution::NWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NWK,
ck::Tuple<>, ck::Tuple<>,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,
......
...@@ -181,7 +181,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -181,7 +181,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto GetWeightTensorDescriptor(index_t GemmNRaw, index_t GemmKRaw) template <
typename BLayout_,
typename std::enable_if<is_same_v<BLayout_, ck::tensor_layout::convolution::KXC> ||
is_same_v<BLayout_, ck::tensor_layout::convolution::KYXC> ||
is_same_v<BLayout_, ck::tensor_layout::convolution::KZYXC>,
bool>::type = false>
static auto MakeBGridDescriptor_N_K(index_t GemmNRaw, index_t GemmKRaw)
{ {
const auto wei_k_yxc_grid_desc = const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw)); make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw));
...@@ -192,7 +198,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -192,7 +198,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return wei_gemmn_gemmk_grid_desc; return wei_gemmn_gemmk_grid_desc;
} }
static auto GetOutputTensorDescriptor(index_t GemmMRaw, index_t GemmN) template <
typename ELayout_,
typename std::enable_if<is_same_v<ELayout_, ck::tensor_layout::convolution::NWK> ||
is_same_v<ELayout_, ck::tensor_layout::convolution::NHWK> ||
is_same_v<ELayout_, ck::tensor_layout::convolution::NDHWK>,
bool>::type = false>
static auto MakeEGridDescriptor_M_N(index_t GemmMRaw, index_t GemmN)
{ {
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
...@@ -205,18 +217,20 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -205,18 +217,20 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return out_gemmm_gemmn_grid_desc; return out_gemmm_gemmn_grid_desc;
} }
template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> template <typename ALayout_,
static auto GetInputTensorDescriptor(index_t N, typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NWC>,
index_t C, bool>::type = false>
index_t GemmMRaw, static auto MakeAGridDescriptor_M_K(index_t N,
index_t GemmKRaw, index_t C,
const std::vector<index_t>& input_spatial_lengths, index_t GemmMRaw,
const std::vector<index_t>& filter_spatial_lengths, index_t GemmKRaw,
const std::vector<index_t>& output_spatial_lengths, const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& conv_filter_strides, const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& conv_filter_dilations, const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& input_left_pads, const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& input_right_pads) const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{ {
const index_t Wi = input_spatial_lengths[0]; const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0]; const index_t Wo = output_spatial_lengths[0];
...@@ -299,18 +313,20 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -299,18 +313,20 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
} }
template <index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> template <typename ALayout_,
static auto GetInputTensorDescriptor(index_t N, typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NHWC>,
index_t C, bool>::type = false>
index_t GemmMRaw, static auto MakeAGridDescriptor_M_K(index_t N,
index_t GemmKRaw, index_t C,
const std::vector<index_t>& input_spatial_lengths, index_t GemmMRaw,
const std::vector<index_t>& filter_spatial_lengths, index_t GemmKRaw,
const std::vector<index_t>& output_spatial_lengths, const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& conv_filter_strides, const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& conv_filter_dilations, const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& input_left_pads, const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& input_right_pads) const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{ {
const index_t Hi = input_spatial_lengths[0]; const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1]; const index_t Wi = input_spatial_lengths[1];
...@@ -409,18 +425,20 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -409,18 +425,20 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
} }
template <index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> template <typename ALayout_,
static auto GetInputTensorDescriptor(index_t N, typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NDHWC>,
index_t C, bool>::type = false>
index_t GemmMRaw, static auto MakeAGridDescriptor_M_K(index_t N,
index_t GemmKRaw, index_t C,
const std::vector<index_t>& input_spatial_lengths, index_t GemmMRaw,
const std::vector<index_t>& filter_spatial_lengths, index_t GemmKRaw,
const std::vector<index_t>& output_spatial_lengths, const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& conv_filter_strides, const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& conv_filter_dilations, const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& input_left_pads, const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& input_right_pads) const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{ {
const index_t Di = input_spatial_lengths[0]; const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1]; const index_t Hi = input_spatial_lengths[1];
...@@ -574,23 +592,23 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -574,23 +592,23 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// A: // A:
const auto in_gemmm_gemmk_grid_desc = const auto in_gemmm_gemmk_grid_desc =
GetInputTensorDescriptor<NDimSpatial>(N, MakeAGridDescriptor_M_K<ALayout>(N,
C, C,
GemmMRaw, GemmMRaw,
GemmKRaw, GemmKRaw,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
// B: // B:
const auto wei_gemmn_gemmk_grid_desc = GetWeightTensorDescriptor(GemmNRaw, GemmKRaw); const auto wei_gemmn_gemmk_grid_desc = MakeBGridDescriptor_N_K<BLayout>(GemmNRaw, GemmKRaw);
// E: // E:
const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmNRaw); const auto out_gemmm_gemmn_grid_desc = MakeEGridDescriptor_M_N<ELayout>(GemmMRaw, GemmNRaw);
return make_tuple( return make_tuple(
in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc); in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment