// host_conv.hpp
#pragma once
#include "host_tensor.hpp"
#include "conv_common.hpp"

template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
12
13
14
15
16
17
18
void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
                              const Tensor<TWei>& wei,
                              Tensor<TOut>& out,
                              const ConvStrides& conv_strides,
                              const ConvDilations& conv_dilations,
                              const InLeftPads& in_left_pads,
                              const InRightPads&)
19
{
20
21
    constexpr auto I0 = ck::Number<0>{};
    constexpr auto I1 = ck::Number<1>{};
22

23
    auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
24
        float v = 0;
25
        for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
26
        {
27
            for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
28
            {
29
30
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
                for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
31
                {
32
33
34
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[3])
35
                    {
36
37
                        v += ck::type_convert<float>(in(n, c, hi, wi)) *
                             ck::type_convert<float>(wei(k, c, y, x));
38
39
40
41
                    }
                }
            }
        }
42
        out(n, k, ho, wo) = ck::type_convert<TOut>(v);
43
44
    };

45
46
47
48
49
    make_ParallelTensorFunctor(f_nchw,
                               out.mDesc.GetLengths()[0],
                               out.mDesc.GetLengths()[1],
                               out.mDesc.GetLengths()[2],
                               out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
50
}