Commit 33975236 authored by Chao Liu
Browse files

clean

parent 9526b9ec
......@@ -12,14 +12,12 @@ namespace tensor_operation {
namespace device {
// Convolution Forward:
// input : input image A[N, Hi, Wi, C],
// input : weight B[K, Y, X, C],
// input : D0[N, Ho, Wo, K], D1[N, Ho, Wo, K], ...
// output : output image E[N, Ho, Wo, K]
// input : input image A[N, C, Hi, Wi],
// input : weight B[K, C, Y, X],
// input : D0[N, K, Ho, Wo], D1[N, K, Ho, Wo], ...
// output : output image E[N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
......
......@@ -84,6 +84,12 @@ __global__ void
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
// input : input image A[N, C, Hi, Wi],
// input : weight B[K, C, Y, X],
// input : D0[N, K, Ho, Wo], D1[N, K, Ho, Wo], ...
// output : output image E[N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -166,6 +172,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
BElementwiseOperation,
CDEElementwiseOperation>
{
namespace ctc = ck::tensor_layout::convolution;
using DeviceOp = DeviceConvFwdMultipleD_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -181,9 +189,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NWC>,
bool>::type = false>
template <typename ALay, typename std::enable_if<is_same_v<ALay, ctc::NWC>, bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......@@ -293,8 +299,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NHWC>,
bool>::type = false>
typename std::enable_if<is_same_v<ALay, ctc::NHWC>, bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......@@ -418,8 +423,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NDHWC>,
bool>::type = false>
typename std::enable_if<is_same_v<ALay, ctc::NDHWC>, bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......@@ -566,9 +570,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// KYXC, K_YXC
// KZYXC, K_ZYXC
template <typename BLay,
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::KXC> ||
is_same_v<BLay, tensor_layout::convolution::KYXC> ||
is_same_v<BLay, tensor_layout::convolution::KZYXC>,
typename std::enable_if<is_same_v<BLay, ctc::KXC> || is_same_v<BLay, ctc::KYXC> ||
is_same_v<BLay, ctc::KZYXC>,
bool>::type = false>
static auto MakeBGridDescriptor_N_K(index_t GemmNRaw, index_t GemmKRaw)
{
......@@ -582,9 +585,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template <typename ELay,
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::NWK> ||
is_same_v<ELay, tensor_layout::convolution::NHWK> ||
is_same_v<ELay, tensor_layout::convolution::NDHWK>,
typename std::enable_if<is_same_v<ELay, ctc::NWK> || is_same_v<ELay, ctc::NHWK> ||
is_same_v<ELay, ctc::NDHWK>,
bool>::type = false>
static auto MakeEGridDescriptor_M_N(index_t GemmMRaw, index_t GemmN)
{
......@@ -929,8 +931,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
// check device
if(get_device_name() == "gfx908")
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment