Commit f7be86b9 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent d7079939
......@@ -158,12 +158,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
#if 0
BlockwiseGenericTensorSliceCopy_v1
#else
BlockwiseGenericTensorSliceCopy_v2
#endif
<BlockSize,
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
......@@ -175,7 +170,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
......@@ -192,12 +188,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in
auto blockwise_wei_copy =
#if 0
BlockwiseGenericTensorSliceCopy_v1
#else
BlockwiseGenericTensorSliceCopy_v2
#endif
<BlockSize,
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
......@@ -209,7 +200,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......
......@@ -51,7 +51,7 @@ template <index_t GridSize,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
{
#if 1
#if 0
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
......@@ -437,6 +437,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// input
constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
......@@ -465,6 +466,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// weight
constexpr auto wei_e_k_global_desc =
transform_tensor_descriptor(wei_k_c_y_x_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
......@@ -487,8 +495,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
print_array("idx1: ", idx1);
print_array("idx0: ", idx0);
}
#else
index_t itmp = get_block_1d_id() + get_thread_local_1d_id();
auto wei_coord1 = make_tensor_coordinate_v2(wei_e_k_global_desc, {itmp, itmp + 1});
auto step_sizes = make_multi_index(EPerBlock, 0);
wei_coord1 += step_sizes;
p_out_global[0] = wei_coord1.GetLowerCoordinate().GetIndex()[0];
p_out_global[1] = wei_coord1.GetLowerCoordinate().GetIndex()[1];
p_out_global[2] = wei_coord1.GetLowerCoordinate().GetIndex()[2];
p_out_global[3] = wei_coord1.GetLowerCoordinate().GetIndex()[3];
#endif
p_out_global[0] = in_e_n1_b_n2_global_desc.CalculateOffset({0, 0, 10, 0});
}
#endif
};
......
......@@ -197,7 +197,7 @@ struct Merge
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow, 1>{}([&](auto ireverse) {
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(carry)
......@@ -213,6 +213,12 @@ struct Merge
carry = true;
}
});
// highest dimension, no out-of-bound check
if(carry)
{
++idx_low_new(0);
}
}
else if(idx_up_diff[0] < 0)
{
......@@ -220,7 +226,7 @@ struct Merge
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow, 1>{}([&](auto ireverse) {
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(borrow)
......@@ -236,6 +242,12 @@ struct Merge
borrow = true;
}
});
// highest dimension, no out-of-bound check
if(borrow)
{
--idx_low_new(0);
}
}
return idx_low_new - idx_low_old;
......
......@@ -70,7 +70,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
......
......@@ -74,7 +74,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
......
......@@ -74,12 +74,12 @@ int main(int argc, char* argv[])
{
using namespace ck;
#if 1
constexpr index_t N = 512;
constexpr index_t C = 16;
#if 0
constexpr index_t N = 256;
constexpr index_t C = 64;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 512;
constexpr index_t K = 256;
constexpr index_t Y = 17;
constexpr index_t X = 17;
......@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
......@@ -117,8 +117,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
......@@ -133,8 +133,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
......@@ -149,8 +149,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
......@@ -165,8 +165,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
......@@ -181,8 +181,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
......@@ -197,8 +197,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 28x28 image
// cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
......@@ -213,8 +213,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
......@@ -229,25 +229,9 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 17x17 input
// cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
constexpr index_t N = 128;
......@@ -261,8 +245,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
......@@ -277,8 +261,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
......@@ -293,8 +277,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
......@@ -309,8 +293,24 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#endif
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment