Commit 0f6cb787 authored by Chao Liu's avatar Chao Liu
Browse files

double index buffer

parent 7d09790a
...@@ -189,11 +189,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -189,11 +189,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
...@@ -237,11 +238,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -237,11 +238,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
GemmKPerBlock, GemmKPerBlock,
GemmMPerThreadSubC, GemmMPerThreadSubC,
GemmNPerThreadSubC, GemmNPerThreadSubC,
GemmKPerThreadLoop,
GemmMLevel0Cluster, GemmMLevel0Cluster,
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop,
ThreadGemmDataPerReadM, ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN, ThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
......
...@@ -29,13 +29,13 @@ int main(int argc, char* argv[]) ...@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 1 #if 0
// 1x1 // 1x1
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 64; constexpr index_t C = 512;
constexpr index_t HI = 56; constexpr index_t HI = 28;
constexpr index_t WI = 56; constexpr index_t WI = 28;
constexpr index_t K = 256; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -44,13 +44,13 @@ int main(int argc, char* argv[]) ...@@ -44,13 +44,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 1x7 // 1x7
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 7; constexpr index_t X = 7;
...@@ -281,7 +281,7 @@ int main(int argc, char* argv[]) ...@@ -281,7 +281,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128; constexpr index_t N = 128;
...@@ -327,7 +327,7 @@ int main(int argc, char* argv[]) ...@@ -327,7 +327,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>; using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>;
#elif 1 #elif 0
// 7x1 filter, 3x0 pad, 17x17 input // 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -434,7 +434,7 @@ int main(int argc, char* argv[]) ...@@ -434,7 +434,7 @@ int main(int argc, char* argv[])
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment