Commit 90acba1d authored by Chao Liu
Browse files

refactor

parent a22c7cf5
...@@ -44,7 +44,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc ...@@ -44,7 +44,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
4, // NXdlPerWave 4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, 1 // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1 8, // ABlockTransferDstScalarPerVector_K1
...@@ -69,10 +69,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio ...@@ -69,10 +69,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle< using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
ck::tensor_layout::convolution::NWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NWK,
ck::Tuple<>, ck::Tuple<>,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,
......
...@@ -181,7 +181,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -181,7 +181,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto GetWeightTensorDescriptor(index_t GemmNRaw, index_t GemmKRaw) template <
typename BLayout_,
typename std::enable_if<is_same_v<BLayout_, ck::tensor_layout::convolution::KXC> ||
is_same_v<BLayout_, ck::tensor_layout::convolution::KYXC> ||
is_same_v<BLayout_, ck::tensor_layout::convolution::KZYXC>,
bool>::type = false>
static auto MakeBGridDescriptor_N_K(index_t GemmNRaw, index_t GemmKRaw)
{ {
const auto wei_k_yxc_grid_desc = const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw)); make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw));
...@@ -192,7 +198,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -192,7 +198,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return wei_gemmn_gemmk_grid_desc; return wei_gemmn_gemmk_grid_desc;
} }
static auto GetOutputTensorDescriptor(index_t GemmMRaw, index_t GemmN) template <
typename ELayout_,
typename std::enable_if<is_same_v<ELayout_, ck::tensor_layout::convolution::NWK> ||
is_same_v<ELayout_, ck::tensor_layout::convolution::NHWK> ||
is_same_v<ELayout_, ck::tensor_layout::convolution::NDHWK>,
bool>::type = false>
static auto MakeEGridDescriptor_M_N(index_t GemmMRaw, index_t GemmN)
{ {
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
...@@ -205,8 +217,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -205,8 +217,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return out_gemmm_gemmn_grid_desc; return out_gemmm_gemmn_grid_desc;
} }
template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> template <typename ALayout_,
static auto GetInputTensorDescriptor(index_t N, typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C, index_t C,
index_t GemmMRaw, index_t GemmMRaw,
index_t GemmKRaw, index_t GemmKRaw,
...@@ -299,8 +313,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -299,8 +313,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
} }
template <index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> template <typename ALayout_,
static auto GetInputTensorDescriptor(index_t N, typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NHWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C, index_t C,
index_t GemmMRaw, index_t GemmMRaw,
index_t GemmKRaw, index_t GemmKRaw,
...@@ -409,8 +425,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -409,8 +425,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
} }
template <index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> template <typename ALayout_,
static auto GetInputTensorDescriptor(index_t N, typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NDHWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C, index_t C,
index_t GemmMRaw, index_t GemmMRaw,
index_t GemmKRaw, index_t GemmKRaw,
...@@ -574,7 +592,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -574,7 +592,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// A: // A:
const auto in_gemmm_gemmk_grid_desc = const auto in_gemmm_gemmk_grid_desc =
GetInputTensorDescriptor<NDimSpatial>(N, MakeAGridDescriptor_M_K<ALayout>(N,
C, C,
GemmMRaw, GemmMRaw,
GemmKRaw, GemmKRaw,
...@@ -587,10 +605,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -587,10 +605,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
input_right_pads); input_right_pads);
// B: // B:
const auto wei_gemmn_gemmk_grid_desc = GetWeightTensorDescriptor(GemmNRaw, GemmKRaw); const auto wei_gemmn_gemmk_grid_desc = MakeBGridDescriptor_N_K<BLayout>(GemmNRaw, GemmKRaw);
// E: // E:
const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmNRaw); const auto out_gemmm_gemmn_grid_desc = MakeEGridDescriptor_M_N<ELayout>(GemmMRaw, GemmNRaw);
return make_tuple( return make_tuple(
in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc); in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment