// eval_test.cpp - tests for building and evaluating migraph::program objects.
#include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
#include <sstream>
#include "test.hpp"

Paul's avatar
Paul committed
8
9
struct sum_op
{
Paul's avatar
Paul committed
10
    std::string name() const { return "sum"; }
Paul's avatar
Paul committed
11
    migraph::argument compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
Paul's avatar
Paul committed
12
    {
Paul's avatar
Paul committed
13
        migraph::argument result;
Paul's avatar
Paul committed
14
        if(args.size() != 2)
Paul's avatar
Paul committed
15
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
16
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
17
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
18
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
19
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
20
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
21
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
22
23

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
24
            args[1].visit_at([&](auto y) { result = migraph::literal{x + y}.get_argument(); });
Paul's avatar
Paul committed
25
26
27
28
        });
        return result;
    }

Paul's avatar
Paul committed
29
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
Paul's avatar
Paul committed
30
    {
Paul's avatar
Paul committed
31
        if(inputs.size() != 2)
Paul's avatar
Paul committed
32
            MIGRAPH_THROW("Wrong inputs");
Paul's avatar
Paul committed
33
34
35
36
        return inputs.front();
    }
};

Paul's avatar
Paul committed
37
38
39
struct minus_op
{
    std::string name() const { return "minus"; }
Paul's avatar
Paul committed
40
    migraph::argument compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
Paul's avatar
Paul committed
41
    {
Paul's avatar
Paul committed
42
        migraph::argument result;
Paul's avatar
Paul committed
43
        if(args.size() != 2)
Paul's avatar
Paul committed
44
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
45
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
46
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
47
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
48
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
49
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
50
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
51
52

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
53
            args[1].visit_at([&](auto y) { result = migraph::literal{x - y}.get_argument(); });
Paul's avatar
Paul committed
54
55
56
57
        });
        return result;
    }

Paul's avatar
Paul committed
58
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
Paul's avatar
Paul committed
59
60
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
61
            MIGRAPH_THROW("Wrong inputs");
Paul's avatar
Paul committed
62
63
64
65
        return inputs.front();
    }
};

Paul's avatar
Paul committed
66
67
struct id_target
{
Paul's avatar
Paul committed
68
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
69
70
    void apply(migraph::program&) const {}
    migraph::context get_context() const { return {}; }
Paul's avatar
Paul committed
71
72
};

Paul's avatar
Paul committed
73
void literal_test1()
Paul's avatar
Paul committed
74
{
Paul's avatar
Paul committed
75
    migraph::program p;
Paul's avatar
Paul committed
76
77
78

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
79
    p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
80
    auto result = p.eval({});
Paul's avatar
Paul committed
81
82
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
83
84
}

Paul's avatar
Paul committed
85
86
void literal_test2()
{
Paul's avatar
Paul committed
87
    migraph::program p;
Paul's avatar
Paul committed
88

Paul's avatar
Paul committed
89
90
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
91
92
93
94
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto result = p.eval({});
Paul's avatar
Paul committed
95
96
    EXPECT(result == migraph::literal{5});
    EXPECT(result != migraph::literal{3});
Paul's avatar
Paul committed
97
98
}

Paul's avatar
Paul committed
99
100
void print_test()
{
Paul's avatar
Paul committed
101
    migraph::program p;
Paul's avatar
Paul committed
102

Paul's avatar
Paul committed
103
    auto x   = p.add_parameter("x", {migraph::shape::int64_type});
Paul's avatar
Paul committed
104
105
106
107
108
109
110
111
112
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, x, two);

    std::stringstream ss;
    ss << p;
    std::string s = ss.str();
    EXPECT(!s.empty());
}

Paul's avatar
Paul committed
113
114
void param_test()
{
Paul's avatar
Paul committed
115
    migraph::program p;
Paul's avatar
Paul committed
116

Paul's avatar
Paul committed
117
118
    auto x = p.add_parameter("x", {migraph::shape::int64_type});
    auto y = p.add_parameter("y", {migraph::shape::int64_type});
Paul's avatar
Paul committed
119

Paul's avatar
Paul committed
120
    p.add_instruction(sum_op{}, x, y);
Paul's avatar
Paul committed
121
    auto result =
Paul's avatar
Paul committed
122
123
124
        p.eval({{"x", migraph::literal{1}.get_argument()}, {"y", migraph::literal{2}.get_argument()}});
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
125
126
}

Paul's avatar
Paul committed
127
128
void replace_test()
{
Paul's avatar
Paul committed
129
    migraph::program p;
Paul's avatar
Paul committed
130
131
132
133
134
135
136

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.replace_instruction(sum, minus_op{}, two, one);

    auto result = p.eval({});
Paul's avatar
Paul committed
137
138
    EXPECT(result == migraph::literal{1});
    EXPECT(result != migraph::literal{3});
Paul's avatar
Paul committed
139
140
}

Paul's avatar
Paul committed
141
142
void insert_replace_test()
{
Paul's avatar
Paul committed
143
    migraph::program p;
Paul's avatar
Paul committed
144

Paul's avatar
Paul committed
145
146
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
147
148
149
150
151
152
153
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
    p.replace_instruction(sum1, minus_op{}, sum0, two);

    auto result = p.eval({});
Paul's avatar
Paul committed
154
155
    EXPECT(result == migraph::literal{4});
    EXPECT(result != migraph::literal{5});
Paul's avatar
Paul committed
156
157
}

Paul's avatar
Paul committed
158
159
void target_test()
{
Paul's avatar
Paul committed
160
    migraph::program p;
Paul's avatar
Paul committed
161
162
163
164
165
166

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    p.compile(id_target{});
    auto result = p.eval({});
Paul's avatar
Paul committed
167
168
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
169
170
}

Paul's avatar
Paul committed
171
172
int main()
{
Paul's avatar
Paul committed
173
174
    literal_test1();
    literal_test2();
Paul's avatar
Paul committed
175
    print_test();
Paul's avatar
Paul committed
176
    param_test();
Paul's avatar
Paul committed
177
    replace_test();
Paul's avatar
Paul committed
178
    insert_replace_test();
Paul's avatar
Paul committed
179
    target_test();
Paul's avatar
Paul committed
180
}