conv_common.hpp 5.19 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
#ifndef CONV_COMMON_HPP
#define CONV_COMMON_HPP
3

4
#include "ConstantTensorDescriptor_deprecated.hpp"
Chao Liu's avatar
Chao Liu committed
5
#include "tensor_descriptor.hpp"
Chao Liu's avatar
Chao Liu committed
6

Chao Liu's avatar
Chao Liu committed
7
8
9
10
11
12
13
14
template <class InDesc,
          class WeiDesc,
          class ConvStrides,
          class ConvDilations,
          class LowerPads,
          class UpperPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor_deprecated(
    InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
Chao Liu's avatar
Chao Liu committed
15
{
Chao Liu's avatar
Chao Liu committed
16
17
    using namespace ck;

Chao Liu's avatar
Chao Liu committed
18
19
20
21
22
23
24
25
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

Chao Liu's avatar
Chao Liu committed
26
27
    static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
Chao Liu's avatar
Chao Liu committed
28
29
30
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

Chao Liu's avatar
Chao Liu committed
31
32
33
34
35
36
37
38
39
40
    constexpr index_t N  = in_desc.GetLength(I0);
    constexpr index_t Hi = in_desc.GetLength(I2);
    constexpr index_t Wi = in_desc.GetLength(I3);

    constexpr index_t K = wei_desc.GetLength(I0);
    constexpr index_t Y = wei_desc.GetLength(I2);
    constexpr index_t X = wei_desc.GetLength(I3);

    constexpr index_t HPadLow = LowerPads{}.Get(I0);
    constexpr index_t WPadLow = LowerPads{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
41

Chao Liu's avatar
Chao Liu committed
42
43
    constexpr index_t HPadUp = UpperPads{}.Get(I0);
    constexpr index_t WPadUp = UpperPads{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
44

Chao Liu's avatar
Chao Liu committed
45
46
47
48
49
    constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
    constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;

    constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1;
    constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1;
50

Chao Liu's avatar
Chao Liu committed
51
    return make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});
52
53
}

54
55
56
57
58
59
template <class InDesc,
          class WeiDesc,
          class ConvStrides,
          class ConvDilations,
          class LowerPads,
          class UpperPads>
Chao Liu's avatar
Chao Liu committed
60
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
61
    InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
62
{
Chao Liu's avatar
Chao Liu committed
63
64
    using namespace ck;

65
66
67
68
69
70
71
72
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

Chao Liu's avatar
Chao Liu committed
73
74
    static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
75
76
77
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

78
79
80
    constexpr index_t N  = in_desc.GetLength(I0);
    constexpr index_t Hi = in_desc.GetLength(I2);
    constexpr index_t Wi = in_desc.GetLength(I3);
81

82
83
84
    constexpr index_t K = wei_desc.GetLength(I0);
    constexpr index_t Y = wei_desc.GetLength(I2);
    constexpr index_t X = wei_desc.GetLength(I3);
85

86
87
    constexpr index_t HPadLow = LowerPads{}.Get(I0);
    constexpr index_t WPadLow = LowerPads{}.Get(I1);
88

89
90
    constexpr index_t HPadUp = UpperPads{}.Get(I0);
    constexpr index_t WPadUp = UpperPads{}.Get(I1);
91

92
93
    constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
    constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
Chao Liu's avatar
Chao Liu committed
94

95
96
97
    constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1;
    constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1;

Chao Liu's avatar
Chao Liu committed
98
    return make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
Chao Liu's avatar
Chao Liu committed
99
}
Chao Liu's avatar
Chao Liu committed
100
101

template <class InDesc, class WeiDesc, class OutDesc>
Chao Liu's avatar
Chao Liu committed
102
constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
Chao Liu's avatar
Chao Liu committed
103
{
Chao Liu's avatar
Chao Liu committed
104
105
    using namespace ck;

Chao Liu's avatar
Chao Liu committed
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t N  = out_desc.GetLength(I0);
    constexpr index_t K  = out_desc.GetLength(I1);
    constexpr index_t Ho = out_desc.GetLength(I2);
    constexpr index_t Wo = out_desc.GetLength(I3);

    constexpr index_t C = wei_desc.GetLength(I1);
    constexpr index_t Y = wei_desc.GetLength(I2);
    constexpr index_t X = wei_desc.GetLength(I3);

    return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}
Chao Liu's avatar
Chao Liu committed
125
126
127
128

/// Returns the number of bytes occupied by a convolution's input, weight and
/// output tensors: sizeof(Float) * (InDesc + WeiDesc + OutDesc element spaces),
/// where each element space comes from the descriptor's static GetElementSpace().
///
/// The Float argument only selects the element type; its value is ignored.
template <class Float, class InDesc, class WeiDesc, class OutDesc>
constexpr std::size_t calculate_convolution_memory_size(Float, InDesc, WeiDesc, OutDesc)
{
    // Widen each element count to std::size_t before summing. Adding the raw
    // counts in their own (possibly 32-bit signed) type can overflow for large
    // problems before the multiplication by sizeof(Float) promotes to size_t.
    const std::size_t element_space = static_cast<std::size_t>(InDesc::GetElementSpace()) +
                                      static_cast<std::size_t>(WeiDesc::GetElementSpace()) +
                                      static_cast<std::size_t>(OutDesc::GetElementSpace());

    return sizeof(Float) * element_space;
}
151
152

#endif