"src/vscode:/vscode.git/clone" did not exist on "ff40d99cb2015a211c49183c8bf7c5ed7e803cea"
op_shape_test.cpp 4.94 KB
Newer Older
1
2
3
4
5
6
7
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <sstream>
#include "test.hpp"

Paul's avatar
Paul committed
8
template <class... Ts>
Paul's avatar
Paul committed
9
void expect_shape(const migraph::shape& expected, const migraph::operation& op, Ts... xs)
10
11
12
{
    migraph::program p;
    std::vector<migraph::shape> shapes{xs...};
Paul's avatar
Paul committed
13
    std::vector<migraph::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
14
15
    std::transform(
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
16
    p.add_instruction(op, args);
Paul's avatar
Paul committed
17
18
    if(p.get_shape() != expected)
    {
19
20
        std::cout << "FAILED: Incorrect shape for " << op.name() << ": ";
        std::cout << expected << " != " << p.get_shape() << std::endl;
Paul's avatar
Paul committed
21
        for(auto&& s : shapes)
22
23
24
25
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
26
template <class... Ts>
Paul's avatar
Paul committed
27
void throws_shape(const migraph::operation& op, Ts... xs)
28
29
30
{
    migraph::program p;
    std::vector<migraph::shape> shapes{xs...};
Paul's avatar
Paul committed
31
    std::vector<migraph::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
32
33
    std::transform(
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
34
    bool thrown = test::throws([&] { p.add_instruction(op, args); });
Paul's avatar
Paul committed
35
36
    if(not thrown)
    {
37
        std::cout << "FAILED: No error found for " << op.name() << ": ";
Paul's avatar
Paul committed
38
        for(auto&& s : shapes)
39
40
41
42
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
43
44
45
46
// Compile-time `false` that nominally depends on its template arguments.
// Using it inside a template (e.g. in a static_assert) makes the condition
// dependent, so it is only evaluated when that template is instantiated.
template <class... Ignored>
struct always_false : std::integral_constant<bool, false>
{
};
47

Paul's avatar
Paul committed
48
template <class... Ts>
Paul's avatar
Paul committed
49
void throws_shape(const migraph::shape&, Ts...)
50
{
Paul's avatar
Paul committed
51
52
    static_assert(always_false<Ts...>{},
                  "An expected shape should not be passed to throws_shape function");
53
54
}

Paul's avatar
Paul committed
55
void batch_norm_inference_shape()
56
57
58
59
{
    const size_t channels = 3;
    migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}};
    migraph::shape vars{migraph::shape::float_type, {channels}};
60
61
62
    expect_shape(s, migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars);
    throws_shape(migraph::op::batch_norm_inference{}, s);
    throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
63
64
}

Paul's avatar
Paul committed
65
void convolution_shape()
66
67
68
69
{
    migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}};
    migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}};
    migraph::shape weights{migraph::shape::float_type, {4, 3, 3, 3}};
70
71
    expect_shape(output, migraph::op::convolution{}, input, weights);
    throws_shape(migraph::op::convolution{}, input);
72
73
74

    migraph::shape input2{migraph::shape::float_type, {3, 3}};
    migraph::shape weights2{migraph::shape::float_type, {3, 3}};
75
76
    throws_shape(migraph::op::convolution{}, input2, weights2);
    throws_shape(migraph::op::convolution{}, input2, weights);
77
78
79
80
81
82
}

void transpose_shape()
{
    migraph::shape input{migraph::shape::float_type, {2, 2}};
    migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}};
83
84
85
    expect_shape(input, migraph::op::transpose{{0, 1}}, input);
    expect_shape(output, migraph::op::transpose{{1, 0}}, input);
    throws_shape(migraph::op::transpose{{1, 2}}, input);
86
87
88
89
90
91
}

void contiguous_shape()
{
    migraph::shape output{migraph::shape::float_type, {2, 2}};
    migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}};
92
93
    expect_shape(output, migraph::op::contiguous{}, input);
    throws_shape(migraph::op::contiguous{}, input, input);
Paul's avatar
Paul committed
94

95
    migraph::shape single{migraph::shape::float_type, {2}};
96
    throws_shape(migraph::op::contiguous{}, single);
97
98
99
100
101
}

void reshape_shape()
{
    migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}};
Paul's avatar
Paul committed
102
103
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
104
105
106
107
    {
        std::vector<std::size_t> lens(new_shape.size());
        std::copy(new_shape.begin(), new_shape.end(), lens.begin());
        migraph::shape output{migraph::shape::float_type, lens};
108
        expect_shape(output, migraph::op::reshape{new_shape}, input);
109
110
    }

Paul's avatar
Paul committed
111
    for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}})
112
    {
113
        throws_shape(migraph::op::reshape{new_shape}, input);
114
115
116
117
118
119
    }
}

void flatten_shape()
{
    migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}};
Paul's avatar
Paul committed
120
    expect_shape(
121
        migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}}, migraph::op::flatten{0}, input);
Paul's avatar
Paul committed
122
    expect_shape(
123
        migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::op::flatten{1}, input);
Paul's avatar
Paul committed
124
    expect_shape(
125
        migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::op::flatten{2}, input);
Paul's avatar
Paul committed
126
    expect_shape(
127
        migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::op::flatten{3}, input);
Paul's avatar
Paul committed
128
    expect_shape(
129
130
        migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}}, migraph::op::flatten{4}, input);
    throws_shape(migraph::op::flatten{5}, input);
131
132
}

Paul's avatar
Paul committed
133
int main()
134
135
136
137
138
139
140
141
{
    batch_norm_inference_shape();
    convolution_shape();
    transpose_shape();
    contiguous_shape();
    reshape_shape();
    flatten_shape();
}