"...git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "1dd38d6cb1312966e41a7e772fd089e812e90cdb"
Commit 03b9544a authored by Chao Liu

Merge branch 'tune_1121' into bwd_data

parents a7a1e3c1 528051d2
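The hunks below move the source/destination address spaces and the destination memory operation of the block- and thread-wise tensor copies from runtime Run() arguments into compile-time template parameters, so the call sites shrink to a plain Run(src, dst). A minimal, self-contained sketch of that pattern follows; the AddressSpace and InMemoryDataOperation enums mirror the ones visible in the diff, while SimpleBlockwiseCopy and its interface are invented here purely for illustration and are not the library's actual API.

// Illustrative sketch only: shows address spaces moving from per-call arguments
// to template parameters, as in the copy classes touched below.
#include <cstddef>

enum class AddressSpace { generic, global, lds, vgpr };  // mirrors the diff
enum class InMemoryDataOperation { none, atomic_add };   // mirrors the diff

// Hypothetical copy class: source/destination address spaces and the store
// operation are compile-time properties, so Run() takes only pointers.
template <AddressSpace SrcAddrSpace,
          AddressSpace DstAddrSpace,
          InMemoryDataOperation DstOp,
          std::size_t N>
struct SimpleBlockwiseCopy
{
    void Run(const float* src, float* dst) const
    {
        for(std::size_t i = 0; i < N; ++i)
        {
            if constexpr(DstOp == InMemoryDataOperation::none)
                dst[i] = src[i];  // plain store
            else
                dst[i] += src[i]; // atomic_add modeled as an accumulate here
        }
    }
};

int main()
{
    float in[8]  = {1, 2, 3, 4, 5, 6, 7, 8};
    float lds[8] = {};

    SimpleBlockwiseCopy<AddressSpace::global,
                        AddressSpace::lds,
                        InMemoryDataOperation::none,
                        8>{}
        .Run(in, lds); // no address-space arguments at the call site
}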
@@ -10,7 +10,6 @@
#include "blockwise_gemm.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
@@ -61,11 +60,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
constexpr auto True = integral_constant<bool, true>{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc =
@@ -158,7 +152,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
InBlockCopyDataPerAccess_B,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
@@ -192,7 +190,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
WeiBlockCopyDstDataPerWrite_K,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
@@ -202,7 +204,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check
@@ -260,10 +261,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(
p_in_global, p_in_block_double, global_address_space, generic_address_space);
blockwise_wei_copy.Run(
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
@@ -294,10 +293,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -323,10 +320,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();
// LDS double buffer: load last data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
@@ -397,17 +392,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
arithmetic_sequence_gen<0, 4, 1>::type,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1})
#if 1
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
#else // tweaking
.Run_optimized_dst_address_calculation(
p_out_thread, p_out_global, generic_address_space, global_address_space);
#endif
OutThreadCopyDataPerAccess_B,
AddressSpace::vgpr,
AddressSpace::global>({0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1})
.Run(p_out_thread, p_out_global);
}
}
};
......
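The kernel hunks above also touch the LDS double-buffering pipeline: preload one buffer, then in the main body load the next tile while running GEMM on the current one, and finally drain the last preloaded tile. A rough, simplified sketch of that ping-pong structure (load and gemm are placeholder lambdas standing in for blockwise_copy.Run and blockwise_gemm.Run; the real kernel also stages data through a thread buffer, which is omitted here):

// Simplified ping-pong double-buffer loop: compute on one buffer while the
// next tile is loaded into the other.
#include <array>
#include <cstdio>

int main()
{
    constexpr int num_tiles = 6;
    std::array<int, 2> buffer{}; // stands in for the two LDS buffers

    auto load = [](int tile, int& dst) { dst = tile; };              // "global -> LDS"
    auto gemm = [](int src) { std::printf("gemm on tile %d\n", src); };

    // preload: fill buffer 0 before entering the loop
    load(0, buffer[0]);

    // main body: overlap the load of tile i+1 with the GEMM on tile i
    for(int tile = 0; tile + 1 < num_tiles; ++tile)
    {
        const int cur  = tile % 2;
        const int next = 1 - cur;
        load(tile + 1, buffer[next]); // "load next data from device mem"
        gemm(buffer[cur]);            // "GEMM on current data"
    }

    // tail: GEMM on the last preloaded tile
    gemm(buffer[(num_tiles - 1) % 2]);
}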
@@ -54,8 +54,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256, each thread holds 64 data
#if 0
// BlockSize = 256, EPerBlock = 8, each thread holds 64 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
@@ -89,6 +89,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
// BlockSize = 256, EPerBlock = 16, each thread holds 64 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
......
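As a back-of-the-envelope check of the "each thread holds 64 data" comments in the v4r1 configurations above: with BlockSize = 256, KPerBlock = 128, BPerBlock = 16, and the N1 = 2, N2 = 4 factors implied by the copy sub-lengths, the per-block output tile is 128 x (16 * 2 * 4) = 16384 accumulators, i.e. 64 per thread. The small sketch below encodes that arithmetic as static_asserts; the formula is my reading of the parameters, not a check performed by the library itself.

#include <cstddef>

int main()
{
    constexpr std::size_t BlockSize = 256;
    constexpr std::size_t KPerBlock = 128;
    constexpr std::size_t BPerBlock = 16;
    constexpr std::size_t N1        = 2; // from InBlockCopySubLengths_E_N1_B_N2
    constexpr std::size_t N2        = 4;

    // per-block output tile, distributed over the threads of the block
    constexpr std::size_t tile_elems       = KPerBlock * BPerBlock * N1 * N2;
    constexpr std::size_t elems_per_thread = tile_elems / BlockSize;

    static_assert(elems_per_thread == 64, "matches the 'each thread holds 64 data' comment");
}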
@@ -51,6 +51,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256, EPerBlock = 8
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128;
@@ -85,7 +86,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 1;
#elif 1
#elif 0
// BlockSize = 256, EPerBlock = 8
// 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256;
@@ -122,6 +124,43 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
#elif 0
// BlockSize = 256, EPerBlock = 16
// 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_B = Sequence<2, 4>;
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
#elif 1
// 1x1 filter, 14x14 image
constexpr index_t BlockSize = 256;
@@ -167,47 +206,43 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded
#else
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
BPerBlock,
KPerBlock,
EPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_B,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_B>{};
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
BPerBlock,
KPerBlock,
EPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_B,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_B>{};
for(index_t i = 0; i < nrepeat; ++i)
{
......
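The GEMM cluster parameters repeated across the v4r4 configurations above have to be mutually consistent: the two cluster levels must tile exactly BlockSize threads, and KPerBlock / BPerBlock must be multiples of the per-cluster extents so the per-thread repeat counts come out integral. The sketch below encodes that reading of the parameters as static_asserts (my interpretation, not the library's own sanity checks; the kernel keys K to the GEMM M dimension and B = merge(N, Ho, Wo) to the GEMM N dimension):

#include <cstddef>

int main()
{
    constexpr std::size_t BlockSize = 256;
    constexpr std::size_t KPerBlock = 128; // GEMM M extent per block
    constexpr std::size_t BPerBlock = 128; // GEMM N extent per block

    constexpr std::size_t MPerThreadSubC = 4, NPerThreadSubC = 4;
    constexpr std::size_t MLevel0 = 4, NLevel0 = 4;
    constexpr std::size_t MLevel1 = 4, NLevel1 = 4;

    // the two cluster levels together must account for every thread in the block
    static_assert(MLevel0 * NLevel0 * MLevel1 * NLevel1 == BlockSize,
                  "cluster layout must fill the block");

    // per-thread repeat counts along M and N
    constexpr std::size_t MRepeat = KPerBlock / (MPerThreadSubC * MLevel0 * MLevel1);
    constexpr std::size_t NRepeat = BPerBlock / (NPerThreadSubC * NLevel0 * NLevel1);

    // per-thread accumulator tile: (2*4) x (2*4) = 64 values for these settings
    static_assert(MRepeat * MPerThreadSubC * NRepeat * NPerThreadSubC == 64,
                  "64 output values per thread");
}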
@@ -58,7 +58,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64;
......