Commit 545d9305 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 37f4e2b6
...@@ -33,9 +33,18 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc, ...@@ -33,9 +33,18 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
#if 1
constexpr auto in_nchw_desc = InDesc{}; constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{}; constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{}; constexpr auto out_nkhw_desc = OutDesc{};
#else
constexpr auto in_nchw_desc =
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr auto wei_kcyx_desc =
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr auto out_nkhw_desc =
make_native_tensor_descriptor(OutDesc::GetLegnths(), OutDesc::GetStrides());
#endif
constexpr index_t N = out_nkhw_desc.GetLength(I0); constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1); constexpr index_t K = out_nkhw_desc.GetLength(I1);
...@@ -88,7 +97,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc, ...@@ -88,7 +97,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1 #elif 0
// BlockSize = 64, each thread hold 64 data // BlockSize = 64, each thread hold 64 data
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -125,6 +134,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc, ...@@ -125,6 +134,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 2;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 2;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif #endif
constexpr index_t N1 = GemmNRepeat; constexpr index_t N1 = GemmNRepeat;
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "device.hpp" #include "device.hpp"
#include "tensor.hpp" #include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp" #include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp" //#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp" #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
template <class T, template <class T,
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
//#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp" //#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp" //#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" //#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp" #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp"
struct GeneratorTensor_1 struct GeneratorTensor_1
{ {
...@@ -76,7 +76,7 @@ int main(int argc, char* argv[]) ...@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
#if 0 #if 0
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 64; constexpr index_t C = 256;
constexpr index_t HI = 56; constexpr index_t HI = 56;
constexpr index_t WI = 56; constexpr index_t WI = 56;
constexpr index_t K = 256; constexpr index_t K = 256;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment