host_conv.hpp 6.12 KB
Newer Older
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "host_tensor.hpp"
3
#include "conv_common.hpp"
4

zjing14's avatar
zjing14 committed
5
6
7
8
9
10
11
template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
12
13
14
15
16
17
18
void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
                              const Tensor<TWei>& wei,
                              Tensor<TOut>& out,
                              const ConvStrides& conv_strides,
                              const ConvDilations& conv_dilations,
                              const InLeftPads& in_left_pads,
                              const InRightPads&)
19
{
20
21
    constexpr auto I0 = ck::Number<0>{};
    constexpr auto I1 = ck::Number<1>{};
22

23
    auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
24
        float v = 0;
25
        for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
26
        {
27
            for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
28
            {
29
30
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
                for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
31
                {
32
33
34
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[3])
35
                    {
36
37
                        v += ck::type_convert<float>(in(n, c, hi, wi)) *
                             ck::type_convert<float>(wei(k, c, y, x));
38
39
40
41
                    }
                }
            }
        }
42
        out(n, k, ho, wo) = ck::type_convert<TOut>(v);
43
44
    };

45
46
47
48
49
    make_ParallelTensorFunctor(f_nchw,
                               out.mDesc.GetLengths()[0],
                               out.mDesc.GetLengths()[1],
                               out.mDesc.GetLengths()[2],
                               out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
50
}
Jianfeng Yan's avatar
Jianfeng Yan committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149

/// Host-side (CPU) reference for a 3D forward convolution.
///
/// The input is addressed as [n][di][hi][wi][c], the weights as [k][z][y][x][c] and the
/// output as out(n, do, ho, wo, k). Input and weight elements are read through raw
/// pointers into `mData` using offsets computed from the lengths only, i.e. both tensors
/// are treated as densely packed in that order.
///
/// For each spatial dimension the input coordinate of filter tap `t` is
///
///     i = o * stride + t * dilation - left_pad
///
/// Instead of testing every tap, the half-open range [tmin, tmax) of taps with
/// 0 <= i < in_len is computed up front:
///
///     tmin = max(0,          ceil((left_pad - o * stride)          / dilation))
///     tmax = min(filter_len, ceil((left_pad - o * stride + in_len) / dilation))
///
/// Products are accumulated in double precision and assigned to the output element.
///
/// @param in             input tensor (read only)
/// @param wei            weight tensor (read only)
/// @param out            output tensor; every element is overwritten
/// @param conv_strides   per-dimension strides, indexed with ck::Number<0/1/2>
/// @param conv_dilations per-dimension dilations, indexed with ck::Number<0/1/2>
/// @param in_left_pads   per-dimension left padding, indexed with ck::Number<0/1/2>
/// @param (unnamed)      right padding; accepted for interface symmetry but not read here
template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor<TIn>& in,
                                   const Tensor<TWei>& wei,
                                   Tensor<TOut>& out,
                                   const ConvStrides& conv_strides,
                                   const ConvDilations& conv_dilations,
                                   const InLeftPads& in_left_pads,
                                   const InRightPads&)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    const auto Di     = in.mDesc.GetLengths()[1];
    const auto Hi     = in.mDesc.GetLengths()[2];
    const auto Wi     = in.mDesc.GetLengths()[3];
    const auto Z      = wei.mDesc.GetLengths()[1];
    const auto Y      = wei.mDesc.GetLengths()[2];
    const auto X      = wei.mDesc.GetLengths()[3];
    const auto C      = wei.mDesc.GetLengths()[4];

    // Signed copies of the convolution parameters. The tap-range arithmetic below goes
    // through negative intermediate values, so it must not be evaluated in unsigned types.
    const int stride_d   = static_cast<int>(conv_strides[I0]);
    const int stride_h   = static_cast<int>(conv_strides[I1]);
    const int stride_w   = static_cast<int>(conv_strides[I2]);
    const int dilation_d = static_cast<int>(conv_dilations[I0]);
    const int dilation_h = static_cast<int>(conv_dilations[I1]);
    const int dilation_w = static_cast<int>(conv_dilations[I2]);
    const int pad_d      = static_cast<int>(in_left_pads[I0]);
    const int pad_h      = static_cast<int>(in_left_pads[I1]);
    const int pad_w      = static_cast<int>(in_left_pads[I2]);

    // First filter tap whose input coordinate is >= 0.
    const auto first_tap = [](int out_idx, int stride, int dilation, int left_pad) {
        return std::max(0, (left_pad - out_idx * stride + dilation - 1) / dilation);
    };

    // One past the last filter tap whose input coordinate is < in_len. The division must
    // round up: with a floor, the last valid tap is lost whenever
    // (left_pad - out_idx * stride + in_len) is not a multiple of the dilation.
    // A non-positive numerator yields a value <= 0, i.e. an empty range.
    const auto end_tap =
        [](int out_idx, int stride, int dilation, int left_pad, int in_len, int filter_len) {
            return std::min(filter_len,
                            (left_pad - out_idx * stride + in_len + dilation - 1) / dilation);
        };

    auto f_ndhwc = [&](auto n, auto do__, auto ho_, auto wo_, auto k) {
        // The output coordinates must be converted to signed integers, otherwise the tap
        // ranges would be wrong whenever (left_pad - o * stride) is negative.
        const int do_ = static_cast<int>(do__);
        const int ho  = static_cast<int>(ho_);
        const int wo  = static_cast<int>(wo_);

        const int zmin = first_tap(do_, stride_d, dilation_d, pad_d);
        const int ymin = first_tap(ho, stride_h, dilation_h, pad_h);
        const int xmin = first_tap(wo, stride_w, dilation_w, pad_w);

        const int zmax = end_tap(
            do_, stride_d, dilation_d, pad_d, static_cast<int>(Di), static_cast<int>(Z));
        const int ymax =
            end_tap(ho, stride_h, dilation_h, pad_h, static_cast<int>(Hi), static_cast<int>(Y));
        const int xmax =
            end_tap(wo, stride_w, dilation_w, pad_w, static_cast<int>(Wi), static_cast<int>(X));

        // Input coordinates that correspond to the first valid tap of each dimension;
        // non-negative by construction of zmin/ymin/xmin.
        const int di_min = do_ * stride_d + zmin * dilation_d - pad_d;
        const int hi_min = ho * stride_h + ymin * dilation_h - pad_h;
        const int wi_min = wo * stride_w + xmin * dilation_w - pad_w;

        double v = 0;

        // Base pointers of the n-th input image and the k-th filter.
        const TIn* in_n   = in.mData.data() + n * Di * Hi * Wi * C;
        const TWei* wei_k = wei.mData.data() + k * Z * Y * X * C;

        int di = di_min;
        for(int z = zmin; z < zmax; ++z, di += dilation_d)
        {
            const TIn* in_n_di  = in_n + di * Hi * Wi * C;
            const TWei* wei_k_z = wei_k + z * Y * X * C;
            int hi              = hi_min;

            for(int y = ymin; y < ymax; ++y, hi += dilation_h)
            {
                const TIn* in_n_di_hi = in_n_di + hi * Wi * C;
                const TWei* wei_k_z_y = wei_k_z + y * X * C;
                int wi                = wi_min;

                for(int x = xmin; x < xmax; ++x, wi += dilation_w)
                {
                    const TIn* in_n_di_hi_wi = in_n_di_hi + wi * C;
                    const TWei* wei_k_z_y_x  = wei_k_z_y + x * C;

                    // Channels are the innermost index of both operands.
                    for(int c = 0; c < static_cast<int>(C); ++c)
                    {
                        v += static_cast<double>(in_n_di_hi_wi[c]) *
                             static_cast<double>(wei_k_z_y_x[c]);
                    }
                }
            }
        }

        out(n, do_, ho, wo, k) = v;
    };

    // Keep a few hardware threads free, as the original code intended, but never let the
    // unsigned subtraction wrap around or reach zero: hardware_concurrency() may report
    // 4 or fewer threads, or 0 when the value cannot be determined.
    constexpr unsigned reserved_threads = 4;
    const unsigned hw_threads           = std::thread::hardware_concurrency();
    const unsigned num_thread =
        hw_threads > reserved_threads ? hw_threads - reserved_threads : 1u;

    make_ParallelTensorFunctor(f_ndhwc,
                               out.mDesc.GetLengths()[0],
                               out.mDesc.GetLengths()[1],
                               out.mDesc.GetLengths()[2],
                               out.mDesc.GetLengths()[3],
                               out.mDesc.GetLengths()[4])(num_thread);
}