conv_common.hip.hpp 3.4 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
2
#include "ConstantTensorDescriptor.hip.hpp"
Chao Liu's avatar
Chao Liu committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

// Output tensor descriptor of a rank-4 convolution with unit stride, no dilation and
// no padding (this helper is deliberately limited to 4d descriptors).
//
// Index convention, as read by the code below:
//   input  descriptor: [N, C, HI, WI]   (dim 1 must match the weight's dim 1)
//   weight descriptor: [K, C, Y,  X ]
//   output descriptor: [N, K, HO, WO]   with HO = HI + 1 - Y, WO = WI + 1 - X
//
// Only the descriptor *types* matter; the arguments are unnamed tag objects and all
// checks happen at compile time.
// The result is built from lengths only via make_ConstantTensorDescriptor; presumably
// that yields default (packed) strides -- confirm in ConstantTensorDescriptor.
template <class InDesc, class WeiDesc>
__host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc,
                                                                                       WeiDesc)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

    constexpr auto N  = in_desc.GetLength(I0);
    constexpr auto HI = in_desc.GetLength(I2);
    constexpr auto WI = in_desc.GetLength(I3);

    constexpr auto K = wei_desc.GetLength(I0);
    constexpr auto Y = wei_desc.GetLength(I2);
    constexpr auto X = wei_desc.GetLength(I3);

    // Without this check a filter larger than the input would give a non-positive
    // output extent (or silently wrap around if the length type is unsigned).
    static_assert(HI >= Y && WI >= X, "filter is larger than input");

    constexpr auto HO = HI + 1 - Y;
    constexpr auto WO = WI + 1 - X;

    return make_ConstantTensorDescriptor(Sequence<N, K, HO, WO>{});
}

// Output tensor descriptor of a rank-4 convolution with unit stride, no dilation and
// explicit spatial padding.
//
// Index convention, as read by the code below:
//   input  descriptor: [N, C, HI, WI]   (dim 1 must match the weight's dim 1)
//   weight descriptor: [K, C, Y,  X ]
//   LowerPads / UpperPads: element 0 is the H padding, element 1 is the W padding
//   output descriptor: [N, K, HO, WO]
//     HO = HI + HPadLow + HPadUp + 1 - Y
//     WO = WI + WPadLow + WPadUp + 1 - X
//
// Only the argument *types* matter; everything is evaluated at compile time.
// The result is built from lengths only via make_ConstantTensorDescriptor; presumably
// that yields default (packed) strides -- confirm in ConstantTensorDescriptor.
template <class InDesc, class WeiDesc, class LowerPads, class UpperPads>
__host__ __device__ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
    InDesc, WeiDesc, LowerPads, UpperPads)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

    constexpr auto N  = in_desc.GetLength(I0);
    constexpr auto HI = in_desc.GetLength(I2);
    constexpr auto WI = in_desc.GetLength(I3);

    constexpr auto K = wei_desc.GetLength(I0);
    constexpr auto Y = wei_desc.GetLength(I2);
    constexpr auto X = wei_desc.GetLength(I3);

    constexpr auto HPadLow = LowerPads{}.Get(I0);
    constexpr auto WPadLow = LowerPads{}.Get(I1);

    constexpr auto HPadUp = UpperPads{}.Get(I0);
    constexpr auto WPadUp = UpperPads{}.Get(I1);

    // The padded input must be at least as large as the filter; otherwise the output
    // extent would be non-positive (or wrap around if the length type is unsigned).
    static_assert(HI + HPadLow + HPadUp >= Y && WI + WPadLow + WPadUp >= X,
                  "filter is larger than padded input");

    constexpr auto HO = HI + HPadLow + HPadUp + 1 - Y;
    constexpr auto WO = WI + WPadLow + WPadUp + 1 - X;

    return make_ConstantTensorDescriptor(Sequence<N, K, HO, WO>{});
}
Chao Liu's avatar
Chao Liu committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

// Nominal floating-point operation count of a rank-4 direct convolution:
//   2 * N * K * Ho * Wo * C * Y * X
// i.e. one multiply plus one add per multiply-accumulate.
// N, K, Ho, Wo are read from the output descriptor ([N, K, Ho, Wo]); C, Y, X from the
// weight descriptor ([K, C, Y, X]). The input descriptor is used only for the
// compile-time consistency checks, mirroring the output-descriptor helpers in this file.
// The leading std::size_t(2) makes the whole left-to-right product std::size_t, so it
// does not overflow in index_t (presumably a 32-bit type -- confirm) for large problems.
template <class InDesc, class WeiDesc, class OutDesc>
__host__ __device__ constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
    static_assert(out_desc.GetDimension() == 4, "output nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

    constexpr index_t N  = out_desc.GetLength(I0);
    constexpr index_t K  = out_desc.GetLength(I1);
    constexpr index_t Ho = out_desc.GetLength(I2);
    constexpr index_t Wo = out_desc.GetLength(I3);

    constexpr index_t C = wei_desc.GetLength(I1);
    constexpr index_t Y = wei_desc.GetLength(I2);
    constexpr index_t X = wei_desc.GetLength(I3);

    return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}