Commit 03b9544a authored by Chao Liu's avatar Chao Liu
Browse files

Merge branch 'tune_1121' into bwd_data

parents a7a1e3c1 528051d2
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
namespace ck { namespace ck {
// B = merge(N, Ho, Wo) // B = merge(N, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
...@@ -61,11 +60,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -61,11 +60,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
constexpr auto in_n_c_hi_wi_global_desc = constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides()); make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc = constexpr auto wei_k_c_y_x_global_desc =
...@@ -158,7 +152,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -158,7 +152,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
1, 1,
1, 1,
InBlockCopyDataPerAccess_B, InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>( InBlockCopyDataPerAccess_B,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, b_block_data_on_global}, {0, 0}); {0, b_block_data_on_global}, {0, 0});
// weight tensor // weight tensor
...@@ -192,7 +190,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -192,7 +190,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
0, 0,
1, 1,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>( WeiBlockCopyDstDataPerWrite_K,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, k_block_data_on_global}, {0, 0}); {0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
...@@ -202,7 +204,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -202,7 +204,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register // register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc); constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check // sanity check
...@@ -260,10 +261,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -260,10 +261,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
blockwise_in_copy.Run( blockwise_in_copy.Run(p_in_global, p_in_block_double);
p_in_global, p_in_block_double, global_address_space, generic_address_space); blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
blockwise_wei_copy.Run(
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
} }
// LDS double buffer: main body // LDS double buffer: main body
...@@ -294,10 +293,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -294,10 +293,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer( blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
...@@ -323,10 +320,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -323,10 +320,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
blockwise_in_copy.RunLoadThreadBuffer( blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
...@@ -397,17 +392,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -397,17 +392,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
3, 3,
OutThreadCopyDataPerAccess_B, OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0}, OutThreadCopyDataPerAccess_B,
{k_thread_data_on_global / K1, AddressSpace::vgpr,
k_thread_data_on_global % K1, AddressSpace::global>({0, 0, 0, 0},
b_thread_data_on_global / B1, {k_thread_data_on_global / K1,
b_thread_data_on_global % B1}) k_thread_data_on_global % K1,
#if 1 b_thread_data_on_global / B1,
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space); b_thread_data_on_global % B1})
#else // tweaking .Run(p_out_thread, p_out_global);
.Run_optimized_dst_address_calculation(
p_out_thread, p_out_global, generic_address_space, global_address_space);
#endif
} }
} }
}; };
......
...@@ -54,8 +54,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -54,8 +54,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1 #if 0
// BlockSize = 256, each thread hold 64 data // BlockSize = 256, EperBlock = 8, each thread hold 64 data
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16; constexpr index_t BPerBlock = 16;
...@@ -89,6 +89,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -89,6 +89,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 0
......
...@@ -51,6 +51,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -51,6 +51,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1 #if 1
// BlockSize = 256, EPerBlock = 8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128; constexpr index_t BPerBlock = 128;
...@@ -85,7 +86,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -85,7 +86,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 1; constexpr index_t OutThreadCopyDataPerAccess_B = 1;
#elif 1 #elif 0
// BlockSize = 256, EPerBlock = 8
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -122,6 +124,43 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -122,6 +124,43 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t OutThreadCopyDataPerAccess_B = 4; constexpr index_t OutThreadCopyDataPerAccess_B = 4;
#elif 0 #elif 0
// BlockSize = 256, EPerBlock = 16
// 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_B = Sequence<2, 4>;
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
#elif 1
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -167,47 +206,43 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -167,47 +206,43 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 0 GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded GridSize,
#else BlockSize,
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer T,
#endif decltype(in_nchw_desc),
<GridSize, decltype(wei_kcyx_desc),
BlockSize, decltype(out_nkhw_desc),
T, ConvStrides,
decltype(in_nchw_desc), ConvDilations,
decltype(wei_kcyx_desc), LeftPads,
decltype(out_nkhw_desc), RightPads,
ConvStrides, BPerBlock,
ConvDilations, KPerBlock,
LeftPads, EPerBlock,
RightPads, GemmMPerThreadSubC,
BPerBlock, GemmNPerThreadSubC,
KPerBlock, GemmMLevel0Cluster,
EPerBlock, GemmNLevel0Cluster,
GemmMPerThreadSubC, GemmMLevel1Cluster,
GemmNPerThreadSubC, GemmNLevel1Cluster,
GemmMLevel0Cluster, GemmKPerThreadLoop,
GemmNLevel0Cluster, GemmDataPerReadA,
GemmMLevel1Cluster, GemmDataPerReadB,
GemmNLevel1Cluster, InBlockCopySubLengths_E_B,
GemmKPerThreadLoop, InBlockCopyClusterLengths_E_B,
GemmDataPerReadA, InBlockCopyThreadClusterArrangeOrder,
GemmDataPerReadB, InBlockCopySrcAccessOrder,
InBlockCopySubLengths_E_B, InBlockCopyDstAccessOrder,
InBlockCopyClusterLengths_E_B, InBlockCopyDataPerAccess_B,
InBlockCopyThreadClusterArrangeOrder, WeiBlockCopySubLengths_E_K,
InBlockCopySrcAccessOrder, WeiBlockCopyClusterLengths_E_K,
InBlockCopyDstAccessOrder, WeiBlockCopyThreadClusterArrangeOrder,
InBlockCopyDataPerAccess_B, WeiBlockCopySrcAccessOrder,
WeiBlockCopySubLengths_E_K, WeiBlockCopyDstAccessOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstDataPerWrite_K,
WeiBlockCopySrcAccessOrder, OutThreadCopyDataPerAccess_B>{};
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_B>{};
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
......
...@@ -58,7 +58,7 @@ int main(int argc, char* argv[]) ...@@ -58,7 +58,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42% // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64; constexpr index_t N = 64;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment