Commit d3146496 authored by Jing Zhang
Browse files

restore nchwc

parent 9f92c019
......@@ -4,7 +4,6 @@
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
template <typename TInWei,
ck::index_t InWeiVectorSize,
typename TAcc,
typename TOut,
typename InLengths,
......@@ -49,7 +48,9 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const auto Y = wei_k_c_y_x_lengths[I2];
const auto X = wei_k_c_y_x_lengths[I3];
#if 0
constexpr auto InWeiVectorSize = 8;
#if 1
const auto C0 = C / Number<InWeiVectorSize>{};
const auto C1 = Number<InWeiVectorSize>{};
......
......@@ -102,7 +102,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nhwc_kyxc_nhwk(
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, 8>;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, 16, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
......
......@@ -24,7 +24,8 @@
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NHWC 1
#define USE_CONV_FWD_V5R1_NHWC 0
#define USE_CONV_FWD_V5R1_NCHWC 1
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
......@@ -33,9 +34,10 @@ enum ConvForwardAlgo
V4R4NCHW, // 0
V4R4R2NHWC, // 1
V6R1NCHW, // 2
V5R1NHWC, // 3
V4R4R2XDLNCHW, // 4
V4R4R4XDLNHWC // 5
V5R1NCHWc, // 3
V5R1NHWC, // 4
V4R4R2XDLNCHW, // 5
V4R4R4XDLNHWC // 6
};
int main(int argc, char* argv[])
......@@ -342,6 +344,32 @@ int main(int argc, char* argv[])
}
#endif
// Dispatch for the V5R1 NCHWc forward-convolution algorithm.
// This section is compiled in only when USE_CONV_FWD_V5R1_NCHWC is nonzero,
// and it runs only when the selected algorithm is ConvForwardAlgo::V5R1NCHWc.
#if USE_CONV_FWD_V5R1_NCHWC
if(algo == ConvForwardAlgo::V5R1NCHWc)
{
// This path accepts only the NCHW tensor layout; any other layout raises
// std::runtime_error before the device routine is called.
if(layout != ConvTensorLayout::NCHW)
{
throw std::runtime_error("wrong! layout");
}
// f_make_for_device_nchw is defined outside this excerpt; its result is
// indexed with I0..I6 below.
// NOTE(review): presumably these seven elements are the tensor lengths and
// the convolution stride/dilation/padding parameters -- confirm against the helper.
const auto tmp = f_make_for_device_nchw();
// Only the input, accumulator and output data types are given as explicit
// template arguments; the seven elements of tmp are forwarded in order,
// followed by in, wei, out_device and nrepeat.
// NOTE(review): nrepeat looks like a repetition count for the kernel run -- confirm.
device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V5R1_NHWC
if(algo == ConvForwardAlgo::V5R1NHWC)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment