// host_col2im.hpp
#pragma once
#include "host_tensor.hpp"

template <typename T,
          typename FilterSizes,
          typename OutputSizes,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads>
void host_col2im(const Tensor<T>& in_eb,
                 Tensor<T>& in_nchw,
                 FilterSizes,
                 OutputSizes,
                 ConvStrides,
                 ConvDilations,
                 LeftPads,
                 RightPads)
{
    using namespace ck;

    int N  = in_nchw.mDesc.GetLengths()[0];
    int C  = in_nchw.mDesc.GetLengths()[1];
Chao Liu's avatar
Chao Liu committed
24
25
    int Hi = in_nchw.mDesc.GetLengths()[2];
    int Wi = in_nchw.mDesc.GetLengths()[3];
Chao Liu's avatar
Chao Liu committed
26
27
28
29

    int Y = FilterSizes{}[0];
    int X = FilterSizes{}[1];

Chao Liu's avatar
Chao Liu committed
30
31
    int Ho = OutputSizes{}[0];
    int Wo = OutputSizes{}[1];
Chao Liu's avatar
Chao Liu committed
32
33
34
35
36
37
38
39

    auto f = [&](auto n, auto c, auto hi, auto wi) {
        double v = 0;

        for(int y = 0; y < Y; ++y)
        {
            int h_tmp = hi + LeftPads{}[0] - y * ConvDilations{}[0];

Chao Liu's avatar
Chao Liu committed
40
            if(h_tmp % ConvStrides{}[0] == 0)
Chao Liu's avatar
Chao Liu committed
41
42
43
            {
                int ho = h_tmp / ConvStrides{}[0];

Chao Liu's avatar
Chao Liu committed
44
                if(ho >= 0 && ho < Ho)
Chao Liu's avatar
Chao Liu committed
45
                {
Chao Liu's avatar
Chao Liu committed
46
                    for(int x = 0; x < X; ++x)
Chao Liu's avatar
Chao Liu committed
47
                    {
Chao Liu's avatar
Chao Liu committed
48
49
50
51
52
                        int w_tmp = wi + LeftPads{}[1] - x * ConvDilations{}[1];

                        if(w_tmp % ConvStrides{}[1] == 0)
                        {
                            int wo = w_tmp / ConvStrides{}[1];
Chao Liu's avatar
Chao Liu committed
53

Chao Liu's avatar
Chao Liu committed
54
55
56
57
                            if(wo >= 0 && wo < Wo && w_tmp % ConvStrides{}[1] == 0)
                            {
                                int e = c * (Y * X) + y * X + x;
                                int b = n * (Ho * Wo) + ho * Wo + wo;
Chao Liu's avatar
Chao Liu committed
58

Chao Liu's avatar
Chao Liu committed
59
60
61
                                v += in_eb(e, b);
                            }
                        }
Chao Liu's avatar
Chao Liu committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
                    }
                }
            }
        }

        in_nchw(n, c, hi, wi) = v;
    };

    auto f_par = make_ParallelTensorFunctor(f,
                                            in_nchw.mDesc.GetLengths()[0],
                                            in_nchw.mDesc.GetLengths()[1],
                                            in_nchw.mDesc.GetLengths()[2],
                                            in_nchw.mDesc.GetLengths()[3]);

    f_par(std::thread::hardware_concurrency());
}