Commit 9c3f33c0 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 4d93ce0e
...@@ -13,15 +13,14 @@ template <class TInWei, ...@@ -13,15 +13,14 @@ template <class TInWei,
class ConvStrides, class ConvStrides,
class ConvDilations, class ConvDilations,
class InLeftPads, class InLeftPads,
class InRightPads, class InRightPads>
class T>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
InDesc, InDesc,
const Tensor<T>& in_n_c_hi_wi, const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc, WeiDesc,
const Tensor<T>& wei_k_c_y_x, const Tensor<TInWei>& wei_k_c_y_x,
OutDesc, OutDesc,
Tensor<T>& out_n_k_ho_wo, Tensor<TOut>& out_n_k_ho_wo,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
InLeftPads, InLeftPads,
...@@ -374,7 +373,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -374,7 +373,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
#endif #endif
constexpr auto conv_driver = constexpr auto conv_driver =
#if 1 #if 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
#elif 0 #elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
......
...@@ -21,21 +21,7 @@ int main(int argc, char* argv[]) ...@@ -21,21 +21,7 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 0 #if 1
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 1;
constexpr index_t WI = 64;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
...@@ -149,7 +135,7 @@ int main(int argc, char* argv[]) ...@@ -149,7 +135,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3, 71x71 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 192; constexpr index_t C = 192;
...@@ -632,9 +618,9 @@ int main(int argc, char* argv[]) ...@@ -632,9 +618,9 @@ int main(int argc, char* argv[])
#if 0 #if 0
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using out_data_t = float;
using acc_data_t = float; using acc_data_t = float;
#else using out_data_t = float;
#elif 1
using in_data_t = int8_t; using in_data_t = int8_t;
constexpr index_t in_vector_size = 4; constexpr index_t in_vector_size = 4;
using acc_data_t = int32_t; using acc_data_t = int32_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment