device_dynamic_col2im_gemmkgemmn_nchw.hpp 5.57 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#pragma once
#include <unistd.h>
#include <stdexcept>
#include <string>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dynamic_gridwise_col2im_gemmkgemmn_nchw.hpp"

template <typename T,
          typename ColDesc,
          typename ImgDesc,
          typename FilterSizes,
          typename OutputSizes,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_dynamic_col2im_gemmkgemmn_nchw(ColDesc,
                                           const Tensor<T>& col_gemmk_gemmn,
                                           ImgDesc,
                                           Tensor<T>& img_n_c_hi_wi,
                                           FilterSizes,
                                           OutputSizes,
                                           ConvStrides,
                                           ConvDilations,
                                           InLeftPads,
                                           InRightPads,
                                           std::size_t nrepeat)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    std::size_t data_sz = sizeof(T);
    DeviceMem col_gemmk_gemmn_device_buf(data_sz * col_gemmk_gemmn.mDesc.GetElementSpace());
    DeviceMem img_n_c_hi_wi_device_buf(data_sz * img_n_c_hi_wi.mDesc.GetElementSpace());

    col_gemmk_gemmn_device_buf.ToDevice(col_gemmk_gemmn.mData.data());
    img_n_c_hi_wi_device_buf.ToDevice(img_n_c_hi_wi.mData.data());

    const auto col_gemmk_gemmn_desc = make_dynamic_native_tensor_descriptor<2>(
        to_multi_index(ColDesc::GetLengths()), to_multi_index(ColDesc::GetStrides()));

    const auto img_n_c_hi_wi_desc = make_dynamic_native_tensor_descriptor<4>(
        to_multi_index(ImgDesc::GetLengths()), to_multi_index(ImgDesc::GetStrides()));

    const auto filter_sizes   = to_multi_index(FilterSizes{});
    const auto out_sizes      = to_multi_index(OutputSizes{});
    const auto conv_strides   = to_multi_index(ConvStrides{});
    const auto conv_dilations = to_multi_index(ConvDilations{});
    const auto in_left_pads   = to_multi_index(InLeftPads{});
    const auto in_right_pads  = to_multi_index(InRightPads{});

    const auto img_gemmk_gemmn_desc = map_img_into_col(img_n_c_hi_wi_desc,
                                                       out_sizes,
                                                       filter_sizes,
                                                       conv_strides,
                                                       conv_dilations,
                                                       in_left_pads,
                                                       in_right_pads);

    const index_t GemmN = col_gemmk_gemmn_desc.GetLength(I1);

#if 1
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmKPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;

    using BlockCopySubLengths_GemmK_GemmN     = Sequence<8, 8>;
    using BlockCopyClusterLengths_GemmK_GemmN = Sequence<16, 16>;
    using BlockCopyThreadClusterArrangeOrder  = Sequence<0, 1>; // [GemmK, GemmN]
    using BlockCopySrcAccessOrder             = Sequence<0, 1>; // [GemmK, GemmN]
    using BlockCopyDstAccessOrder             = Sequence<0, 1>; // [GemmK, GemmN]

    constexpr index_t BlockCopyDataPerAccess_GemmN = 1;
#endif

    const index_t GridSize = GemmN / GemmNPerBlock;

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_col2im =
        DynamicGridwiseCol2Im_gemmkgemmn_nchw<BlockSize,
                                              GemmKPerBlock,
                                              GemmNPerBlock,
                                              BlockCopySubLengths_GemmK_GemmN,
                                              BlockCopyClusterLengths_GemmK_GemmN,
                                              BlockCopyThreadClusterArrangeOrder,
                                              BlockCopySrcAccessOrder,
                                              BlockCopyDstAccessOrder,
                                              BlockCopyDataPerAccess_GemmN>{};

    for(index_t i = 0; i < 1; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<decltype(gridwise_col2im),
                                                 const T* const __restrict__,
                                                 T* const __restrict__,
                                                 decltype(col_gemmk_gemmn_desc),
                                                 decltype(img_gemmk_gemmn_desc)>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          const_cast<const T* const __restrict__>(
                              static_cast<T*>(col_gemmk_gemmn_device_buf.GetDeviceBuffer())),
                          const_cast<T* const __restrict__>(
                              static_cast<T*>(img_n_c_hi_wi_device_buf.GetDeviceBuffer())),
                          col_gemmk_gemmn_desc,
                          img_gemmk_gemmn_desc);
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;

        std::cout << "Average time : " << ave_time << " ms" << std::endl;
    }

    img_n_c_hi_wi_device_buf.FromDevice(img_n_c_hi_wi.mData.data());
}