Commit ed966e7f authored by Jing Zhang's avatar Jing Zhang
Browse files

clean

parent d3146496
......@@ -78,7 +78,7 @@
// experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
......
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
template <typename TInWei,
typename TAcc,
......@@ -48,7 +48,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const auto Y = wei_k_c_y_x_lengths[I2];
const auto X = wei_k_c_y_x_lengths[I3];
constexpr auto InWeiVectorSize = 8;
constexpr auto InWeiVectorSize = 4;
#if 1
const auto C0 = C / Number<InWeiVectorSize>{};
......@@ -106,16 +106,16 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t E1 = 2 * 9;
constexpr index_t E2 = 8;
constexpr index_t EPerBlock = 2;
constexpr index_t E1 = 4 * 9;
constexpr index_t E2 = C1;
constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, 8>;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, 16, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
......@@ -123,7 +123,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
constexpr index_t CThreadTransferDstScalarPerVector_K = 8;
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#endif
constexpr auto conv_driver =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment