eval_test.cpp 4.7 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
Paul's avatar
Paul committed
5
#include <sstream>
Paul's avatar
Paul committed
6
7
#include "test.hpp"

Paul's avatar
Paul committed
8
9
struct sum_op
{
Paul's avatar
Paul committed
10
    std::string name() const { return "sum"; }
Paul's avatar
Paul committed
11
12
    migraph::argument
    compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
Paul's avatar
Paul committed
13
    {
Paul's avatar
Paul committed
14
        migraph::argument result;
Paul's avatar
Paul committed
15
        if(args.size() != 2)
Paul's avatar
Paul committed
16
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
17
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
18
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
19
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
20
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
21
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
22
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
23
24

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
25
            args[1].visit_at([&](auto y) { result = migraph::literal{x + y}.get_argument(); });
Paul's avatar
Paul committed
26
27
28
29
        });
        return result;
    }

Paul's avatar
Paul committed
30
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
Paul's avatar
Paul committed
31
    {
Paul's avatar
Paul committed
32
        if(inputs.size() != 2)
Paul's avatar
Paul committed
33
            MIGRAPH_THROW("Wrong inputs");
Paul's avatar
Paul committed
34
35
36
37
        return inputs.front();
    }
};

Paul's avatar
Paul committed
38
39
40
struct minus_op
{
    std::string name() const { return "minus"; }
Paul's avatar
Paul committed
41
42
    migraph::argument
    compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
Paul's avatar
Paul committed
43
    {
Paul's avatar
Paul committed
44
        migraph::argument result;
Paul's avatar
Paul committed
45
        if(args.size() != 2)
Paul's avatar
Paul committed
46
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
47
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
48
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
49
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
50
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
51
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
52
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
53
54

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
55
            args[1].visit_at([&](auto y) { result = migraph::literal{x - y}.get_argument(); });
Paul's avatar
Paul committed
56
57
58
59
        });
        return result;
    }

Paul's avatar
Paul committed
60
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
Paul's avatar
Paul committed
61
62
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
63
            MIGRAPH_THROW("Wrong inputs");
Paul's avatar
Paul committed
64
65
66
67
        return inputs.front();
    }
};

Paul's avatar
Paul committed
68
69
struct id_target
{
Paul's avatar
Paul committed
70
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
71
    std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; }
Paul's avatar
Paul committed
72
    migraph::context get_context() const { return {}; }
Paul's avatar
Paul committed
73
74
};

Paul's avatar
Paul committed
75
void literal_test1()
Paul's avatar
Paul committed
76
{
Paul's avatar
Paul committed
77
    migraph::program p;
Paul's avatar
Paul committed
78
79
80

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
81
    p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
82
    auto result = p.eval({});
Paul's avatar
Paul committed
83
84
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
85
86
}

Paul's avatar
Paul committed
87
88
void literal_test2()
{
Paul's avatar
Paul committed
89
    migraph::program p;
Paul's avatar
Paul committed
90

Paul's avatar
Paul committed
91
92
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
93
94
95
96
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto result = p.eval({});
Paul's avatar
Paul committed
97
98
    EXPECT(result == migraph::literal{5});
    EXPECT(result != migraph::literal{3});
Paul's avatar
Paul committed
99
100
}

Paul's avatar
Paul committed
101
102
void print_test()
{
Paul's avatar
Paul committed
103
    migraph::program p;
Paul's avatar
Paul committed
104

Paul's avatar
Paul committed
105
    auto x   = p.add_parameter("x", {migraph::shape::int64_type});
Paul's avatar
Paul committed
106
107
108
109
110
111
112
113
114
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, x, two);

    std::stringstream ss;
    ss << p;
    std::string s = ss.str();
    EXPECT(!s.empty());
}

Paul's avatar
Paul committed
115
116
void param_test()
{
Paul's avatar
Paul committed
117
    migraph::program p;
Paul's avatar
Paul committed
118

Paul's avatar
Paul committed
119
120
    auto x = p.add_parameter("x", {migraph::shape::int64_type});
    auto y = p.add_parameter("y", {migraph::shape::int64_type});
Paul's avatar
Paul committed
121

Paul's avatar
Paul committed
122
    p.add_instruction(sum_op{}, x, y);
Paul's avatar
Paul committed
123
124
    auto result = p.eval(
        {{"x", migraph::literal{1}.get_argument()}, {"y", migraph::literal{2}.get_argument()}});
Paul's avatar
Paul committed
125
126
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
127
128
}

Paul's avatar
Paul committed
129
130
void replace_test()
{
Paul's avatar
Paul committed
131
    migraph::program p;
Paul's avatar
Paul committed
132
133
134
135
136
137
138

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.replace_instruction(sum, minus_op{}, two, one);

    auto result = p.eval({});
Paul's avatar
Paul committed
139
140
    EXPECT(result == migraph::literal{1});
    EXPECT(result != migraph::literal{3});
Paul's avatar
Paul committed
141
142
}

Paul's avatar
Paul committed
143
144
void insert_replace_test()
{
Paul's avatar
Paul committed
145
    migraph::program p;
Paul's avatar
Paul committed
146

Paul's avatar
Paul committed
147
148
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
149
150
151
152
153
154
155
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
    p.replace_instruction(sum1, minus_op{}, sum0, two);

    auto result = p.eval({});
Paul's avatar
Paul committed
156
157
    EXPECT(result == migraph::literal{4});
    EXPECT(result != migraph::literal{5});
Paul's avatar
Paul committed
158
159
}

Paul's avatar
Paul committed
160
161
void target_test()
{
Paul's avatar
Paul committed
162
    migraph::program p;
Paul's avatar
Paul committed
163
164
165
166
167
168

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    p.compile(id_target{});
    auto result = p.eval({});
Paul's avatar
Paul committed
169
170
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
171
172
}

Paul's avatar
Paul committed
173
174
int main()
{
Paul's avatar
Paul committed
175
176
    literal_test1();
    literal_test2();
Paul's avatar
Paul committed
177
    print_test();
Paul's avatar
Paul committed
178
    param_test();
Paul's avatar
Paul committed
179
    replace_test();
Paul's avatar
Paul committed
180
    insert_replace_test();
Paul's avatar
Paul committed
181
    target_test();
Paul's avatar
Paul committed
182
}