eval_test.cpp 4.71 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
Paul's avatar
Paul committed
5
#include <sstream>
Paul's avatar
Paul committed
6
7
#include "test.hpp"

Paul's avatar
Paul committed
8
9
struct sum_op
{
Paul's avatar
Paul committed
10
    std::string name() const { return "sum"; }
Paul's avatar
Paul committed
11
12
    migraph::argument
    compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
Paul's avatar
Paul committed
13
    {
Paul's avatar
Paul committed
14
        migraph::argument result;
Paul's avatar
Paul committed
15
        if(args.size() != 2)
Paul's avatar
Paul committed
16
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
17
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
18
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
19
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
20
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
21
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
22
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
23
24

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
25
            args[1].visit_at([&](auto y) { result = migraph::literal{x + y}.get_argument(); });
Paul's avatar
Paul committed
26
27
28
29
        });
        return result;
    }

Paul's avatar
Paul committed
30
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
Paul's avatar
Paul committed
31
    {
Paul's avatar
Paul committed
32
        if(inputs.size() != 2)
Paul's avatar
Paul committed
33
            MIGRAPH_THROW("Wrong inputs");
Paul's avatar
Paul committed
34
35
36
37
        return inputs.front();
    }
};

Paul's avatar
Paul committed
38
39
40
struct minus_op
{
    std::string name() const { return "minus"; }
Paul's avatar
Paul committed
41
42
    migraph::argument
    compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
Paul's avatar
Paul committed
43
    {
Paul's avatar
Paul committed
44
        migraph::argument result;
Paul's avatar
Paul committed
45
        if(args.size() != 2)
Paul's avatar
Paul committed
46
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
47
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
48
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
49
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
50
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
51
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
52
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
53
54

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
55
            args[1].visit_at([&](auto y) { result = migraph::literal{x - y}.get_argument(); });
Paul's avatar
Paul committed
56
57
58
59
        });
        return result;
    }

Paul's avatar
Paul committed
60
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
Paul's avatar
Paul committed
61
62
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
63
            MIGRAPH_THROW("Wrong inputs");
Paul's avatar
Paul committed
64
65
66
67
        return inputs.front();
    }
};

Paul's avatar
Paul committed
68
69
struct id_target
{
Paul's avatar
Paul committed
70
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
71
72
73
74
    std::vector<migraph::pass> get_passes(migraph::context&) const
    {
        return {};
    }
Paul's avatar
Paul committed
75
    migraph::context get_context() const { return {}; }
Paul's avatar
Paul committed
76
77
};

Paul's avatar
Paul committed
78
void literal_test1()
Paul's avatar
Paul committed
79
{
Paul's avatar
Paul committed
80
    migraph::program p;
Paul's avatar
Paul committed
81
82
83

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
84
    p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
85
    auto result = p.eval({});
Paul's avatar
Paul committed
86
87
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
88
89
}

Paul's avatar
Paul committed
90
91
void literal_test2()
{
Paul's avatar
Paul committed
92
    migraph::program p;
Paul's avatar
Paul committed
93

Paul's avatar
Paul committed
94
95
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
96
97
98
99
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto result = p.eval({});
Paul's avatar
Paul committed
100
101
    EXPECT(result == migraph::literal{5});
    EXPECT(result != migraph::literal{3});
Paul's avatar
Paul committed
102
103
}

Paul's avatar
Paul committed
104
105
void print_test()
{
Paul's avatar
Paul committed
106
    migraph::program p;
Paul's avatar
Paul committed
107

Paul's avatar
Paul committed
108
    auto x   = p.add_parameter("x", {migraph::shape::int64_type});
Paul's avatar
Paul committed
109
110
111
112
113
114
115
116
117
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, x, two);

    std::stringstream ss;
    ss << p;
    std::string s = ss.str();
    EXPECT(!s.empty());
}

Paul's avatar
Paul committed
118
119
void param_test()
{
Paul's avatar
Paul committed
120
    migraph::program p;
Paul's avatar
Paul committed
121

Paul's avatar
Paul committed
122
123
    auto x = p.add_parameter("x", {migraph::shape::int64_type});
    auto y = p.add_parameter("y", {migraph::shape::int64_type});
Paul's avatar
Paul committed
124

Paul's avatar
Paul committed
125
    p.add_instruction(sum_op{}, x, y);
Paul's avatar
Paul committed
126
127
    auto result = p.eval(
        {{"x", migraph::literal{1}.get_argument()}, {"y", migraph::literal{2}.get_argument()}});
Paul's avatar
Paul committed
128
129
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
130
131
}

Paul's avatar
Paul committed
132
133
void replace_test()
{
Paul's avatar
Paul committed
134
    migraph::program p;
Paul's avatar
Paul committed
135
136
137
138
139
140
141

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.replace_instruction(sum, minus_op{}, two, one);

    auto result = p.eval({});
Paul's avatar
Paul committed
142
143
    EXPECT(result == migraph::literal{1});
    EXPECT(result != migraph::literal{3});
Paul's avatar
Paul committed
144
145
}

Paul's avatar
Paul committed
146
147
void insert_replace_test()
{
Paul's avatar
Paul committed
148
    migraph::program p;
Paul's avatar
Paul committed
149

Paul's avatar
Paul committed
150
151
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
152
153
154
155
156
157
158
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
    p.replace_instruction(sum1, minus_op{}, sum0, two);

    auto result = p.eval({});
Paul's avatar
Paul committed
159
160
    EXPECT(result == migraph::literal{4});
    EXPECT(result != migraph::literal{5});
Paul's avatar
Paul committed
161
162
}

Paul's avatar
Paul committed
163
164
void target_test()
{
Paul's avatar
Paul committed
165
    migraph::program p;
Paul's avatar
Paul committed
166
167
168
169
170
171

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    p.compile(id_target{});
    auto result = p.eval({});
Paul's avatar
Paul committed
172
173
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
174
175
}

Paul's avatar
Paul committed
176
177
int main()
{
Paul's avatar
Paul committed
178
179
    literal_test1();
    literal_test2();
Paul's avatar
Paul committed
180
    print_test();
Paul's avatar
Paul committed
181
    param_test();
Paul's avatar
Paul committed
182
    replace_test();
Paul's avatar
Paul committed
183
    insert_replace_test();
Paul's avatar
Paul committed
184
    target_test();
Paul's avatar
Paul committed
185
}