Commit 78af7d67 authored by Chao Liu
Browse files

refactor

parent 8d15144c
...@@ -13,11 +13,11 @@ ...@@ -13,11 +13,11 @@
// AMD inline asm // AMD inline asm
#ifndef CK_USE_AMD_INLINE_ASM #ifndef CK_USE_AMD_INLINE_ASM
#define CK_USE_AMD_INLINE_ASM 0 #define CK_USE_AMD_INLINE_ASM 1
#endif #endif
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM #ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0 #define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif #endif
// AMD buffer addressing // AMD buffer addressing
......
...@@ -172,7 +172,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -172,7 +172,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1 #elif 0
// cdata = 64, BlockSize = 256, 128x128x16 // cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -563,7 +563,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -563,7 +563,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 1
// cdata = 64, BlockSize = 64, 32x128x3 // cdata = 64, BlockSize = 64, 32x128x3
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -187,7 +187,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -187,7 +187,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 1 #elif 0
// cdata = 64, BlockSize = 256, 128x128x16 // cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -255,7 +255,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -255,7 +255,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 1 #elif 0
// cdata = 64, BlockSize = 256, 128x128x16 // cdata = 64, BlockSize = 256, 128x128x16
// GemmBBlockCopySrcDataPerRead_GemmN = 4 // GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
...@@ -793,7 +793,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -793,7 +793,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0 #elif 1
// cdata = 64, BlockSize = 64, 32x128x3 // cdata = 64, BlockSize = 64, 32x128x3
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_fp16.hpp" //#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_fp16.hpp"
//#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" //#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp" //#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"
...@@ -128,17 +128,17 @@ int main(int argc, char* argv[]) ...@@ -128,17 +128,17 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>; using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>;
#elif 0 #elif 1
// 3x3, 299x299 stride=2 // 3x3, 299x299 stride=2
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 3;
constexpr index_t HI = 14; constexpr index_t HI = 299;
constexpr index_t WI = 14; constexpr index_t WI = 299;
constexpr index_t K = 1024; constexpr index_t K = 32;
constexpr index_t Y = 1; constexpr index_t Y = 3;
constexpr index_t X = 1; constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
...@@ -521,7 +521,7 @@ int main(int argc, char* argv[]) ...@@ -521,7 +521,7 @@ int main(int argc, char* argv[])
print_sequence("ConvStrides", ConvStrides{}); print_sequence("ConvStrides", ConvStrides{});
print_sequence("ConvDilations", ConvDilations{}); print_sequence("ConvDilations", ConvDilations{});
#if 0 #if 1
using in_data_t = float; using in_data_t = float;
using out_data_t = float; using out_data_t = float;
#else #else
...@@ -606,7 +606,7 @@ int main(int argc, char* argv[]) ...@@ -606,7 +606,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -618,7 +618,7 @@ int main(int argc, char* argv[]) ...@@ -618,7 +618,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_fp16(in_nchw_desc, device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_fp16(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -630,7 +630,7 @@ int main(int argc, char* argv[]) ...@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment