// host_conv_bwd_data.hpp
#pragma once

#include <cstddef>
#include <stdexcept>
#include <thread>

#include "host_tensor.hpp"

// Host-side reference for the backward-data pass of a direct 2D convolution.
//
// Every element of `in` is overwritten with the sum, over all filter taps
// (y, x) and all k, of out(...) * wei(...) for the output positions the tap
// maps onto. An input row `hi` and a tap `y` map to the output row
//
//     ho = (hi + in_left_pads[0] - y * conv_dilations[0]) / conv_strides[0]
//
// and contribute only when the division is exact and 0 <= ho < Ho. The width
// axis is handled the same way using index 1 of the pad/dilation/stride
// arguments.
//
// Tensor index order per layout:
//   NCHW: in(n, c, hi, wi), wei(k, c, y, x), out(n, k, ho, wo)
//   NHWC: in(n, hi, wi, c), wei(k, y, x, c), out(n, ho, wo, k)
//
// Parameters:
//   in             - destination tensor; each element is assigned exactly once.
//   wei            - filter tensor (read only).
//   out            - output-side tensor (read only).
//   conv_strides   - strides, [0] = height, [1] = width.
//   conv_dilations - dilations, [0] = height, [1] = width.
//   in_left_pads   - left padding, [0] = height, [1] = width.
//   in_right_pads  - not read here; the output extents Ho/Wo are taken from
//                    `out`'s descriptor instead.
//   layout         - selects the index order above; defaults to NCHW.
//
// Throws std::runtime_error for any layout other than NCHW or NHWC.
template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_direct_convolution_backward_data(Tensor<TIn>& in,
                                           const Tensor<TWei>& wei,
                                           const Tensor<TOut>& out,
                                           const ConvStrides& conv_strides,
                                           const ConvDilations& conv_dilations,
                                           const InLeftPads& in_left_pads,
                                           const InRightPads& in_right_pads,
                                           const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // Kept in the signature for callers; intentionally unused (see header comment).
    static_cast<void>(in_right_pads);

    // Convert the convolution parameters to signed int once, so that the
    // coordinate arithmetic below is plain signed arithmetic instead of relying
    // on implicit unsigned wrap-around followed by narrowing.
    const int pad_h      = static_cast<int>(in_left_pads[I0]);
    const int pad_w      = static_cast<int>(in_left_pads[I1]);
    const int dilation_h = static_cast<int>(conv_dilations[I0]);
    const int dilation_w = static_cast<int>(conv_dilations[I1]);
    const int stride_h   = static_cast<int>(conv_strides[I0]);
    const int stride_w   = static_cast<int>(conv_strides[I1]);

    // Maps one input coordinate and one filter tap to an output coordinate.
    // Returns true and stores the coordinate in `out_idx` when the tap lands
    // exactly on an output element inside [0, out_len); returns false otherwise.
    const auto to_output_index = [](int in_idx,
                                    int left_pad,
                                    int tap,
                                    int dilation,
                                    int stride,
                                    int out_len,
                                    int& out_idx) {
        const int shifted = in_idx + left_pad - tap * dilation;

        // A negative value can only produce a negative (invalid) output index,
        // and a remainder means the tap falls between two strided positions.
        if(shifted < 0 || shifted % stride != 0)
        {
            return false;
        }

        out_idx = shifted / stride;
        return out_idx < out_len;
    };

    // Computes one element of `in` for the NCHW index order.
    auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
        const int K  = static_cast<int>(wei.mDesc.GetLengths()[I0]);
        const int Y  = static_cast<int>(wei.mDesc.GetLengths()[I2]);
        const int X  = static_cast<int>(wei.mDesc.GetLengths()[I3]);
        const int Ho = static_cast<int>(out.mDesc.GetLengths()[I2]);
        const int Wo = static_cast<int>(out.mDesc.GetLengths()[I3]);

        // Accumulate in double regardless of the tensor element types.
        double v = 0;

        for(int y = 0; y < Y; ++y)
        {
            int ho = 0;
            if(!to_output_index(static_cast<int>(hi), pad_h, y, dilation_h, stride_h, Ho, ho))
            {
                continue;
            }

            for(int x = 0; x < X; ++x)
            {
                int wo = 0;
                if(!to_output_index(static_cast<int>(wi), pad_w, x, dilation_w, stride_w, Wo, wo))
                {
                    continue;
                }

                for(int k = 0; k < K; ++k)
                {
                    v += out(n, k, ho, wo) * wei(k, c, y, x);
                }
            }
        }

        in(n, c, hi, wi) = v;
    };

    // Computes one element of `in` for the NHWC index order.
    auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) {
        const int K  = static_cast<int>(wei.mDesc.GetLengths()[I0]);
        const int Y  = static_cast<int>(wei.mDesc.GetLengths()[I1]);
        const int X  = static_cast<int>(wei.mDesc.GetLengths()[I2]);
        const int Ho = static_cast<int>(out.mDesc.GetLengths()[I1]);
        const int Wo = static_cast<int>(out.mDesc.GetLengths()[I2]);

        // Accumulate in double regardless of the tensor element types.
        double v = 0;

        for(int y = 0; y < Y; ++y)
        {
            int ho = 0;
            if(!to_output_index(static_cast<int>(hi), pad_h, y, dilation_h, stride_h, Ho, ho))
            {
                continue;
            }

            for(int x = 0; x < X; ++x)
            {
                int wo = 0;
                if(!to_output_index(static_cast<int>(wi), pad_w, x, dilation_w, stride_w, Wo, wo))
                {
                    continue;
                }

                for(int k = 0; k < K; ++k)
                {
                    v += out(n, ho, wo, k) * wei(k, y, x, c);
                }
            }
        }

        in(n, hi, wi, c) = v;
    };

    // Run the per-element functor over all four dimensions of `in`. The value
    // passed to the functor is std::thread::hardware_concurrency()
    // (presumably the worker-thread count; confirm against
    // make_ParallelTensorFunctor).
    switch(layout)
    {
    case ConvTensorLayout::NCHW:
        make_ParallelTensorFunctor(f_nchw,
                                   in.mDesc.GetLengths()[0],
                                   in.mDesc.GetLengths()[1],
                                   in.mDesc.GetLengths()[2],
                                   in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
        break;
    case ConvTensorLayout::NHWC:
        make_ParallelTensorFunctor(f_nhwc,
                                   in.mDesc.GetLengths()[0],
                                   in.mDesc.GetLengths()[1],
                                   in.mDesc.GetLengths()[2],
                                   in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
        break;
    default: throw std::runtime_error("wrong! not supported layout");
    }
}