// op_shape_test.cpp - shape-inference tests for migraph operators.
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <type_traits>
#include <vector>
#include "test.hpp"

Paul's avatar
Paul committed
8
template <class... Ts>
Paul's avatar
Paul committed
9
void expect_shape(const migraph::shape& expected, const migraph::operation& op, Ts... xs)
10
11
12
{
    migraph::program p;
    std::vector<migraph::shape> shapes{xs...};
Paul's avatar
Paul committed
13
    std::vector<migraph::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
14
15
    std::transform(
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
16
    p.add_instruction(op, args);
Paul's avatar
Paul committed
17
18
    if(p.get_shape() != expected)
    {
19
20
        std::cout << "FAILED: Incorrect shape for " << op.name() << ": ";
        std::cout << expected << " != " << p.get_shape() << std::endl;
Paul's avatar
Paul committed
21
        for(auto&& s : shapes)
22
23
24
25
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
26
template <class... Ts>
Paul's avatar
Paul committed
27
void throws_shape(const migraph::operation& op, Ts... xs)
28
29
30
{
    migraph::program p;
    std::vector<migraph::shape> shapes{xs...};
Paul's avatar
Paul committed
31
    std::vector<migraph::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
32
33
    std::transform(
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
34
    bool thrown = test::throws([&] { p.add_instruction(op, args); });
Paul's avatar
Paul committed
35
36
    if(not thrown)
    {
37
        std::cout << "FAILED: No error found for " << op.name() << ": ";
Paul's avatar
Paul committed
38
        for(auto&& s : shapes)
39
40
41
42
            std::cout << "    " << s << std::endl;
    }
}

/// A constant `false` that is nevertheless dependent on `Ts...`.
///
/// Used inside `static_assert` so the assertion only fires when the enclosing
/// template is instantiated, not when it is merely defined.
template <class...>
struct always_false : std::integral_constant<bool, false>
{
};
47

Paul's avatar
Paul committed
48
template <class... Ts>
Paul's avatar
Paul committed
49
void throws_shape(const migraph::shape&, Ts...)
50
{
Paul's avatar
Paul committed
51
52
    static_assert(always_false<Ts...>{},
                  "An expected shape should not be passed to throws_shape function");
53
54
}

Paul's avatar
Paul committed
55
void batch_norm_inference_shape()
56
57
58
59
{
    const size_t channels = 3;
    migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}};
    migraph::shape vars{migraph::shape::float_type, {channels}};
60
61
62
    expect_shape(s, migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars);
    throws_shape(migraph::op::batch_norm_inference{}, s);
    throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
63
64
}

Paul's avatar
Paul committed
65
void convolution_shape()
66
67
68
69
{
    migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}};
    migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}};
    migraph::shape weights{migraph::shape::float_type, {4, 3, 3, 3}};
70
71
    expect_shape(output, migraph::op::convolution{}, input, weights);
    throws_shape(migraph::op::convolution{}, input);
72
73
74

    migraph::shape input2{migraph::shape::float_type, {3, 3}};
    migraph::shape weights2{migraph::shape::float_type, {3, 3}};
75
76
    throws_shape(migraph::op::convolution{}, input2, weights2);
    throws_shape(migraph::op::convolution{}, input2, weights);
77
78
79
80
81
82
}

void transpose_shape()
{
    migraph::shape input{migraph::shape::float_type, {2, 2}};
    migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}};
83
84
85
    expect_shape(input, migraph::op::transpose{{0, 1}}, input);
    expect_shape(output, migraph::op::transpose{{1, 0}}, input);
    throws_shape(migraph::op::transpose{{1, 2}}, input);
86
87
88
89
90
91
}

void contiguous_shape()
{
    migraph::shape output{migraph::shape::float_type, {2, 2}};
    migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}};
92
93
    expect_shape(output, migraph::op::contiguous{}, input);
    throws_shape(migraph::op::contiguous{}, input, input);
Paul's avatar
Paul committed
94

95
    migraph::shape single{migraph::shape::float_type, {2}};
Khalique's avatar
Khalique committed
96
    expect_shape(single, migraph::op::contiguous{}, single);
97
98
99
100
101
}

void reshape_shape()
{
    migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}};
Paul's avatar
Paul committed
102
103
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
104
105
106
107
    {
        std::vector<std::size_t> lens(new_shape.size());
        std::copy(new_shape.begin(), new_shape.end(), lens.begin());
        migraph::shape output{migraph::shape::float_type, lens};
108
        expect_shape(output, migraph::op::reshape{new_shape}, input);
109
110
    }

Paul's avatar
Paul committed
111
    for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}})
112
    {
113
        throws_shape(migraph::op::reshape{new_shape}, input);
114
115
116
117
118
119
    }
}

void flatten_shape()
{
    migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}};
Scott Thornton's avatar
Scott Thornton committed
120
121
122
    expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}},
                 migraph::op::flatten{0},
                 input);
Paul's avatar
Paul committed
123
    expect_shape(
124
        migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::op::flatten{1}, input);
Paul's avatar
Paul committed
125
    expect_shape(
126
        migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::op::flatten{2}, input);
Paul's avatar
Paul committed
127
    expect_shape(
128
        migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::op::flatten{3}, input);
Scott Thornton's avatar
Scott Thornton committed
129
130
131
    expect_shape(migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}},
                 migraph::op::flatten{4},
                 input);
132
    throws_shape(migraph::op::flatten{5}, input);
133
134
}

Scott Thornton's avatar
Scott Thornton committed
135
136
137
138
void slice_shape()
{
    migraph::shape input{migraph::shape::int32_type, {2, 2, 3}};
    expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
Paul's avatar
Paul committed
139
                 migraph::op::slice{{2}, {1}, {3}},
Scott Thornton's avatar
Scott Thornton committed
140
141
                 input);
    expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
Paul's avatar
Paul committed
142
                 migraph::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
Scott Thornton's avatar
Scott Thornton committed
143
144
                 input);
    expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
Paul's avatar
Paul committed
145
                 migraph::op::slice{{2}, {2}, {10}},
Scott Thornton's avatar
Scott Thornton committed
146
147
                 input);
}
Scott Thornton's avatar
Scott Thornton committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180

void multibroadcast_shape()
{
    {
        std::vector<std::size_t> lens{4,2,5,3};
        migraph::shape input{migraph::shape::float_type, {2,1,3}};
        expect_shape(migraph::shape{migraph::shape::float_type, lens, {0,3,0,1}},
            migraph::op::multibroadcast{lens}, input);
    }
    {
        std::vector<std::size_t> lens{4,2,5,3};
        migraph::shape input{migraph::shape::float_type, {2,1,1}};
        expect_shape(migraph::shape{migraph::shape::float_type, lens, {0,1,0,0}},
            migraph::op::multibroadcast{lens}, input);
    }
    {
        std::vector<std::size_t> lens{4,1,1,3};
        migraph::shape input{migraph::shape::float_type, {4,1,1,1}};
        expect_shape(migraph::shape{migraph::shape::float_type, lens, {1,1,1,0}},
            migraph::op::multibroadcast{lens}, input);
    }
    {
        std::vector<std::size_t> lens{4,1,3};
        migraph::shape input{migraph::shape::float_type, {4,1,1,1}};
        throws_shape(migraph::op::multibroadcast{lens}, input);
    }
    {
        std::vector<std::size_t> lens{4,1,3};
        migraph::shape input{migraph::shape::float_type, {}};
        throws_shape(migraph::op::multibroadcast{lens}, input);
    }
}

Paul's avatar
Paul committed
181
int main()
182
{
Scott Thornton's avatar
Scott Thornton committed
183
    multibroadcast_shape();
184
185
186
187
188
189
    batch_norm_inference_shape();
    convolution_shape();
    transpose_shape();
    contiguous_shape();
    reshape_shape();
    flatten_shape();
Scott Thornton's avatar
Scott Thornton committed
190
    slice_shape();
191
}