Commit 90acba1d authored by Chao Liu
Browse files

refactor

parent a22c7cf5
......@@ -44,7 +44,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 0, 2>, 1 // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
......@@ -69,10 +69,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::Tuple<>,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
AccDataType,
......
......@@ -181,7 +181,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto GetWeightTensorDescriptor(index_t GemmNRaw, index_t GemmKRaw)
template <
typename BLayout_,
typename std::enable_if<is_same_v<BLayout_, ck::tensor_layout::convolution::KXC> ||
is_same_v<BLayout_, ck::tensor_layout::convolution::KYXC> ||
is_same_v<BLayout_, ck::tensor_layout::convolution::KZYXC>,
bool>::type = false>
static auto MakeBGridDescriptor_N_K(index_t GemmNRaw, index_t GemmKRaw)
{
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw));
......@@ -192,7 +198,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return wei_gemmn_gemmk_grid_desc;
}
static auto GetOutputTensorDescriptor(index_t GemmMRaw, index_t GemmN)
template <
typename ELayout_,
typename std::enable_if<is_same_v<ELayout_, ck::tensor_layout::convolution::NWK> ||
is_same_v<ELayout_, ck::tensor_layout::convolution::NHWK> ||
is_same_v<ELayout_, ck::tensor_layout::convolution::NDHWK>,
bool>::type = false>
static auto MakeEGridDescriptor_M_N(index_t GemmMRaw, index_t GemmN)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
......@@ -205,8 +217,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return out_gemmm_gemmn_grid_desc;
}
template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(index_t N,
template <typename ALayout_,
typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmKRaw,
......@@ -299,8 +313,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
template <index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(index_t N,
template <typename ALayout_,
typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NHWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmKRaw,
......@@ -409,8 +425,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
template <index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(index_t N,
template <typename ALayout_,
typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NDHWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmKRaw,
......@@ -574,7 +592,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// A:
const auto in_gemmm_gemmk_grid_desc =
GetInputTensorDescriptor<NDimSpatial>(N,
MakeAGridDescriptor_M_K<ALayout>(N,
C,
GemmMRaw,
GemmKRaw,
......@@ -587,10 +605,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
input_right_pads);
// B:
const auto wei_gemmn_gemmk_grid_desc = GetWeightTensorDescriptor(GemmNRaw, GemmKRaw);
const auto wei_gemmn_gemmk_grid_desc = MakeBGridDescriptor_N_K<BLayout>(GemmNRaw, GemmKRaw);
// E:
const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmNRaw);
const auto out_gemmm_gemmn_grid_desc = MakeEGridDescriptor_M_N<ELayout>(GemmMRaw, GemmNRaw);
return make_tuple(
in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment