// device_dynamic_col2im_gemmkgemmn_nchw.hpp
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_dynamic_col2im_gemmkgemmn_nchw.hpp"

template <typename T,
          typename ColDesc,
          typename ImgDesc,
          typename FilterSizes,
          typename OutputSizes,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_dynamic_col2im_gemmkgemmn_nchw(ColDesc,
                                           const Tensor<T>& col_gemmk_gemmn,
                                           ImgDesc,
                                           Tensor<T>& img_n_c_hi_wi,
                                           FilterSizes,
                                           OutputSizes,
                                           ConvStrides,
                                           ConvDilations,
                                           InLeftPads,
                                           InRightPads,
                                           std::size_t nrepeat)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    std::size_t data_sz = sizeof(T);
    DeviceMem col_gemmk_gemmn_device_buf(data_sz * col_gemmk_gemmn.mDesc.GetElementSpace());
    DeviceMem img_n_c_hi_wi_device_buf(data_sz * img_n_c_hi_wi.mDesc.GetElementSpace());

    col_gemmk_gemmn_device_buf.ToDevice(col_gemmk_gemmn.mData.data());
    img_n_c_hi_wi_device_buf.ToDevice(img_n_c_hi_wi.mData.data());

43
    const auto col_gemmk_gemmn_desc = make_dynamic_naive_tensor_descriptor<2>(
Chao Liu's avatar
Chao Liu committed
44
45
        to_multi_index(ColDesc::GetLengths()), to_multi_index(ColDesc::GetStrides()));

46
    const auto img_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor<4>(
Chao Liu's avatar
Chao Liu committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
        to_multi_index(ImgDesc::GetLengths()), to_multi_index(ImgDesc::GetStrides()));

    const auto filter_sizes   = to_multi_index(FilterSizes{});
    const auto out_sizes      = to_multi_index(OutputSizes{});
    const auto conv_strides   = to_multi_index(ConvStrides{});
    const auto conv_dilations = to_multi_index(ConvDilations{});
    const auto in_left_pads   = to_multi_index(InLeftPads{});
    const auto in_right_pads  = to_multi_index(InRightPads{});

    const auto img_gemmk_gemmn_desc = map_img_into_col(img_n_c_hi_wi_desc,
                                                       out_sizes,
                                                       filter_sizes,
                                                       conv_strides,
                                                       conv_dilations,
                                                       in_left_pads,
                                                       in_right_pads);

    const index_t GemmN = col_gemmk_gemmn_desc.GetLength(I1);

#if 1
    constexpr index_t BlockSize = 256;

Chao Liu's avatar
Chao Liu committed
69
    constexpr index_t GemmKPerBlock = 8;
Chao Liu's avatar
Chao Liu committed
70
71
    constexpr index_t GemmNPerBlock = 128;

Chao Liu's avatar
Chao Liu committed
72
73
    using BlockCopySubLengths_GemmK_GemmN     = Sequence<1, 8>;
    using BlockCopyClusterLengths_GemmK_GemmN = Sequence<8, 16>;
Chao Liu's avatar
Chao Liu committed
74
75
76
77
78
79
80
81
82
83
84
85
    using BlockCopyThreadClusterArrangeOrder  = Sequence<0, 1>; // [GemmK, GemmN]
    using BlockCopySrcAccessOrder             = Sequence<0, 1>; // [GemmK, GemmN]
    using BlockCopyDstAccessOrder             = Sequence<0, 1>; // [GemmK, GemmN]

    constexpr index_t BlockCopyDataPerAccess_GemmN = 1;
#endif

    const index_t GridSize = GemmN / GemmNPerBlock;

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_col2im =
86
        GridwiseDynamicCol2Im_gemmkgemmn_nchw<BlockSize,
Chao Liu's avatar
Chao Liu committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
                                              GemmKPerBlock,
                                              GemmNPerBlock,
                                              BlockCopySubLengths_GemmK_GemmN,
                                              BlockCopyClusterLengths_GemmK_GemmN,
                                              BlockCopyThreadClusterArrangeOrder,
                                              BlockCopySrcAccessOrder,
                                              BlockCopyDstAccessOrder,
                                              BlockCopyDataPerAccess_GemmN>{};

    for(index_t i = 0; i < 1; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<decltype(gridwise_col2im),
                                                 const T* const __restrict__,
                                                 T* const __restrict__,
                                                 decltype(col_gemmk_gemmn_desc),
                                                 decltype(img_gemmk_gemmn_desc)>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          const_cast<const T* const __restrict__>(
                              static_cast<T*>(col_gemmk_gemmn_device_buf.GetDeviceBuffer())),
                          const_cast<T* const __restrict__>(
                              static_cast<T*>(img_n_c_hi_wi_device_buf.GetDeviceBuffer())),
                          col_gemmk_gemmn_desc,
                          img_gemmk_gemmn_desc);
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;

        std::cout << "Average time : " << ave_time << " ms" << std::endl;
    }

    img_n_c_hi_wi_device_buf.FromDevice(img_n_c_hi_wi.mData.data());
}