"vscode:/vscode.git/clone" did not exist on "1410850eccabfb0b41cc3d92925aa29d9c071974"
Commit 5b36aead authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent cc0fa73a
...@@ -84,10 +84,10 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -84,10 +84,10 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerBlock = 2; constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4; constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 8; constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8; constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1; constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1; constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 4; constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimH = 4;
...@@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 1
// for 3x3, 56x56 // for 3x3, 56x56
constexpr index_t NPerBlock = 32; constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 64;
...@@ -209,10 +209,26 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -209,10 +209,26 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t WoPerBlock = 2; constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4; constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1; constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1; constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 0
...@@ -248,7 +264,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -248,7 +264,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 1 #elif 0
// for 1x1, 14x14, Pascal // for 1x1, 14x14, Pascal
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
...@@ -290,7 +306,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -290,7 +306,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 0 #if 1
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
#else #else
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
......
...@@ -421,7 +421,7 @@ int main(int argc, char* argv[]) ...@@ -421,7 +421,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 3x3, 56x56 // 3x3, 56x56
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 64; constexpr index_t C = 64;
...@@ -430,6 +430,9 @@ int main(int argc, char* argv[]) ...@@ -430,6 +430,9 @@ int main(int argc, char* argv[])
constexpr index_t K = 64; constexpr index_t K = 64;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0 #elif 0
// 3x3, 58x58 // 3x3, 58x58
constexpr index_t N = 64; constexpr index_t N = 64;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment