operation.cpp 6.21 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
#include <migraphx/operation.hpp>
Paul's avatar
Paul committed
3
#include <migraphx/context.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
#include <sstream>
#include <string>
#include "test.hpp"

struct simple_operation
{
Paul's avatar
Paul committed
10
    template <class T, class F>
Paul's avatar
Paul committed
11
12
    static auto reflect(T& x, F f)
    {
Paul's avatar
Paul committed
13
        return migraphx::pack(f(x.data, "data"));
Paul's avatar
Paul committed
14
    }
Paul's avatar
Paul committed
15
    int data = 1;
Paul's avatar
Paul committed
16
    std::string name() const { return "simple"; }
Paul's avatar
Paul committed
17
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
Paul's avatar
Paul committed
18
    {
Paul's avatar
Paul committed
19
        MIGRAPHX_THROW("not computable");
Paul's avatar
Paul committed
20
    }
Paul's avatar
Paul committed
21
22
23
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
Paul's avatar
Paul committed
24
    {
Paul's avatar
Paul committed
25
        MIGRAPHX_THROW("not computable");
Paul's avatar
Paul committed
26
    }
Paul's avatar
Paul committed
27
    friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
Paul's avatar
Paul committed
28
    {
Paul's avatar
Paul committed
29
        os << op.name() << "[" << op.data << "]";
Paul's avatar
Paul committed
30
31
        return os;
    }
Paul's avatar
Paul committed
32
33
};

Paul's avatar
Paul committed
34
35
36
struct simple_operation_no_print
{
    std::string name() const { return "simple"; }
Paul's avatar
Paul committed
37
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
Paul's avatar
Paul committed
38
    {
Paul's avatar
Paul committed
39
        MIGRAPHX_THROW("not computable");
Paul's avatar
Paul committed
40
    }
Paul's avatar
Paul committed
41
42
43
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
Paul's avatar
Paul committed
44
    {
Paul's avatar
Paul committed
45
        MIGRAPHX_THROW("not computable");
Paul's avatar
Paul committed
46
    }
Paul's avatar
Paul committed
47
48
};

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// Operation that provides a compile() member in addition to the usual
// name/compute/compute_shape set. compute() and compute_shape() simply pass the
// first input through, or yield a default-constructed result when there is none.
struct compilable_op
{
    std::string name() const
    {
        return "compilable";
    }

    // Returns the first argument, or an empty argument if none were supplied.
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(not args.empty())
            return args.front();
        return {};
    }

    // Returns the first input shape, or an empty shape if none were supplied.
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(not inputs.empty())
            return inputs.front();
        return {};
    }

    // Reports index 0 regardless of the shapes passed in.
    int output_alias(const std::vector<migraphx::shape>&) const
    {
        return 0;
    }

    // Produces a value holding the single entry "compiled" -> true.
    migraphx::value
    compile(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&)
    {
        return {{"compiled", true}};
    }
};

Paul's avatar
Paul committed
76
// Wrapping a simple_operation in migraphx::operation and then copying that
// wrapper must yield objects that all compare equal.
TEST_CASE(operation_copy_test)
{
    simple_operation s{};
    // The copies are the behaviour under test, hence the lint suppressions.
    migraphx::operation op1 = s;   // NOLINT
    migraphx::operation op2 = op1; // NOLINT
    // cppcheck-suppress duplicateExpression
    EXPECT(s == op1);
    // cppcheck-suppress duplicateExpression
    EXPECT(op2 == op1);
}

Paul Fultz II's avatar
Paul Fultz II committed
87
88
89
90
91
92
93
94
95
// Assigning a concrete operation into a default-constructed
// migraphx::operation must leave the wrapper equal to the assigned value.
TEST_CASE(operation_copy_assign_test)
{
    simple_operation source{};
    migraphx::operation erased;
    erased = source;
    // cppcheck-suppress duplicateExpression
    EXPECT(source == erased);
}

Paul's avatar
Paul committed
96
// Equality between operations must follow the wrapped value: a wrapper created
// before `data` was changed differs from one created afterwards, while a copy
// of a wrapper equals its source.
TEST_CASE(operation_equal_test)
{
    simple_operation s{};
    migraphx::operation op1 = s; // holds data == 1
    s.data                  = 2;
    migraphx::operation op2 = op1; // NOLINT
    migraphx::operation op3 = s;   // NOLINT

    EXPECT(s != op1);
    EXPECT(op2 == op1);
    EXPECT(op3 != op2);
    EXPECT(op3 != op1);
}

Paul's avatar
Paul committed
110
111
112
// Empty type with none of the members the operation structs in this file
// declare; the any_cast test uses it as a target type that never matches.
struct not_operation
{
    // Intentionally left without members.
};
Paul's avatar
Paul committed
113

Paul's avatar
Paul committed
114
// any_cast on a migraphx::operation must recover the wrapped simple_operation,
// return nullptr for a pointer cast to an unrelated type, and throw for a
// reference cast to an unrelated type.
TEST_CASE(operation_any_cast)
{
    migraphx::operation op1 = simple_operation{};
    EXPECT(migraphx::any_cast<simple_operation>(op1).data == 1);
    EXPECT(migraphx::any_cast<not_operation*>(&op1) == nullptr);
    EXPECT(test::throws([&] { migraphx::any_cast<not_operation&>(op1); }));
    // Same checks with a non-default payload.
    migraphx::operation op2 = simple_operation{2};
    EXPECT(migraphx::any_cast<simple_operation>(op2).data == 2);
    EXPECT(migraphx::any_cast<not_operation*>(&op2) == nullptr);
}

Paul's avatar
Paul committed
125
// Streaming a wrapped simple_operation must produce the output of its own
// operator<<, i.e. the name followed by the data value in brackets.
TEST_CASE(operation_print)
{
    migraphx::operation op = simple_operation{};
    std::stringstream ss;
    ss << op;
    std::string s = ss.str();
    EXPECT(s == "simple[1]");
}

Paul's avatar
Paul committed
134
// Streaming a wrapped operation that defines no operator<< of its own must
// print just the operation's name.
TEST_CASE(operation_default_print)
{
    migraphx::operation op = simple_operation_no_print{};
    std::stringstream ss;
    ss << op;
    std::string s = ss.str();
    EXPECT(s == "simple");
}

Paul's avatar
Paul committed
143
144
145
146
147
148
149
struct final_operation
{
    std::string name() const { return "final"; }
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
    {
        MIGRAPHX_THROW("not computable");
    }
Paul's avatar
Paul committed
150
151
152
153
    void
    finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
    {
    }
Paul's avatar
Paul committed
154
155
156
157
158
159
160
161
162
};

struct final_operation_throw
{
    std::string name() const { return "final"; }
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
    {
        MIGRAPHX_THROW("not computable");
    }
Paul's avatar
Paul committed
163
164
    [[gnu::noreturn]] void
    finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
Paul's avatar
Paul committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
    {
        MIGRAPHX_THROW("finalize");
    }
};

// An operation type without a finalize() member must not be reported as
// having one.
TEST_CASE(check_has_finalize_simple)
{
    migraphx::operation wrapped = simple_operation{};
    const auto reports_finalize = migraphx::has_finalize(wrapped);
    EXPECT(not reports_finalize);
}

// An operation type that declares finalize() must be reported as having one.
TEST_CASE(check_has_finalize)
{
    migraphx::operation wrapped = final_operation{};
    const auto reports_finalize = migraphx::has_finalize(wrapped);
    EXPECT(reports_finalize);
}

// Calling finalize() through the wrapper on an operation with an empty
// finalize() member must complete without raising.
TEST_CASE(check_run_finalize)
{
    migraphx::context context{};
    migraphx::operation wrapped = final_operation{};
    wrapped.finalize(context, {}, {});
}

// Calling finalize() through the wrapper on an operation that declares no
// finalize() member must also complete without raising.
TEST_CASE(check_run_finalize_simple)
{
    migraphx::context context{};
    migraphx::operation wrapped = simple_operation{};
    wrapped.finalize(context, {}, {});
}

// An exception raised inside the wrapped operation's finalize() must propagate
// out of the wrapper's finalize() call.
TEST_CASE(check_run_finalize_throw)
{
    migraphx::operation op = final_operation_throw{};
    migraphx::context ctx{};
    EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); }));
}

203
204
205
206
207
208
209
210
211
212
213
// The member to_value() of a wrapped simple_operation must yield only the
// reflected field: "data" with its default of 1.
TEST_CASE(check_to_value1)
{
    migraphx::operation wrapped = simple_operation{};
    auto actual                 = wrapped.to_value();
    const migraphx::value expected{{"data", 1}};
    EXPECT(actual == expected);
}

// The free function migraphx::to_value applied to a wrapped simple_operation
// must yield the operation name plus the reflected fields nested under
// "operator".
TEST_CASE(check_to_value2)
{
    migraphx::operation op = simple_operation{};
    auto v                 = migraphx::to_value(op);
    EXPECT(v == migraphx::value{{"name", "simple"}, {"operator", {{"data", 1}}}});
}

// Loading {"data": 3} into a default simple_operation through the wrapper's
// from_value() must make it equal to an operation constructed with data == 3.
TEST_CASE(check_from_value1)
{
    migraphx::operation expected = simple_operation{3};
    migraphx::operation loaded   = simple_operation{};

    loaded.from_value({{"data", 3}});
    EXPECT(loaded == expected);
}

// Building a simple_operation with the free function migraphx::from_value
// from {"data": 3} must equal an operation constructed with data == 3.
TEST_CASE(check_from_value2)
{
    migraphx::operation built    = migraphx::from_value<simple_operation>({{"data", 3}});
    migraphx::operation expected = simple_operation{3};

    EXPECT(built == expected);
}

234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
// compile() called through the wrapper must forward to compilable_op::compile
// and hand back its value, whose "compiled" entry is true.
TEST_CASE(compile)
{
    migraphx::context context{};
    migraphx::operation wrapped = compilable_op{};
    auto result                 = wrapped.compile(context, {}, {});
    const auto compiled_flag    = result.at("compiled").to<bool>();
    EXPECT(compiled_flag == true);
}

// compile() called through the wrapper on an operation that declares no
// compile() member must return an empty value.
TEST_CASE(compile_non_compilable)
{
    migraphx::context context{};
    migraphx::operation wrapped = simple_operation{};
    auto result                 = wrapped.compile(context, {}, {});
    EXPECT(result.empty());
}

Paul's avatar
Paul committed
250
// Entry point: forwards the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}