basic_ops.hpp 5.13 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/program.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>
Paul's avatar
Paul committed
4
5
6
7

struct sum_op
{
    std::string name() const { return "sum"; }
Paul's avatar
Paul committed
8
9
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
10
    {
Paul's avatar
Paul committed
11
        migraphx::argument result;
Paul's avatar
Paul committed
12
        if(args.size() != 2)
Paul's avatar
Paul committed
13
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
14
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
15
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
16
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
17
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
18
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
19
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
20
21

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
22
            args[1].visit_at([&](auto y) { result = migraphx::literal{x + y}.get_argument(); });
Paul's avatar
Paul committed
23
24
25
26
        });
        return result;
    }

Paul's avatar
Paul committed
27
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
28
29
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
30
            MIGRAPHX_THROW("Wrong inputs");
Paul's avatar
Paul committed
31
32
33
34
35
36
37
        return inputs.front();
    }
};

struct minus_op
{
    std::string name() const { return "minus"; }
Paul's avatar
Paul committed
38
39
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
40
    {
Paul's avatar
Paul committed
41
        migraphx::argument result;
Paul's avatar
Paul committed
42
        if(args.size() != 2)
Paul's avatar
Paul committed
43
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
44
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
45
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
46
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
47
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
48
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
49
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
50
51

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
52
            args[1].visit_at([&](auto y) { result = migraphx::literal{x - y}.get_argument(); });
Paul's avatar
Paul committed
53
54
55
56
        });
        return result;
    }

Paul's avatar
Paul committed
57
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
58
59
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
60
            MIGRAPHX_THROW("Wrong inputs");
Paul's avatar
Paul committed
61
62
63
        return inputs.front();
    }
};
Paul's avatar
Paul committed
64
65
66
67

struct pass_op
{
    std::string name() const { return "pass"; }
Paul's avatar
Paul committed
68
69
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
70
71
72
73
74
75
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
76
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
77
78
79
80
81
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
Shucai Xiao's avatar
Shucai Xiao committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105

    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

struct mod_pass_op
{
    std::string name() const { return "mod_pass"; }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs,
                                  std::vector<migraphx::module_ref> mods) const
    {
        if(!mods.empty())
        {
            auto out_shapes = mods[0]->get_output_shapes();
            return out_shapes[0];
        }
        if(!inputs.empty())
        {
            return inputs.front();
        }

        return {};
    }

Paul's avatar
Paul committed
106
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
107
};
108

Paul's avatar
Paul committed
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Test operator "unary_pass": a pass-through restricted to one input.
// compute() forwards the first argument (an empty argument if there is none).
// compute_shape() throws "Wrong inputs" unless exactly one shape is given, and
// then returns that shape. output_alias() says the result aliases input 0.
struct unary_pass_op
{
    std::string name() const { return "unary_pass"; }

    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(not args.empty())
            return args.front();
        return migraphx::argument{};
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        const bool exactly_one = inputs.size() == 1;
        if(not exactly_one)
            MIGRAPHX_THROW("Wrong inputs");
        return inputs.front();
    }

    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

129
130
131
struct pass_standard_op
{
    std::string name() const { return "pass"; }
Paul's avatar
Paul committed
132
133
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
134
135
136
137
138
139
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
140
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
141
    {
Paul's avatar
Paul committed
142
        for(auto&& input : inputs)
143
144
145
146
147
148
149
150
        {
            if(not input.standard())
                throw std::runtime_error("Not standard shape");
        }
        if(inputs.empty())
            return {};
        return inputs.front();
    }
Paul's avatar
Paul committed
151
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
152
153
};

154
155
156
struct nop
{
    std::string name() const { return "nop"; }
Paul's avatar
Paul committed
157
158
159
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
160
161
162
163
    {
        return {};
    }

Paul's avatar
Paul committed
164
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
165
};
166

Paul's avatar
Paul committed
167
// Float literal with lens {2, 2} (no explicit strides) holding 1, 2, 3, 4.
inline migraphx::literal get_2x2()
{
    const migraphx::shape s{migraphx::shape::float_type, {2, 2}};
    return migraphx::literal{s, {1, 2, 3, 4}};
}

Paul's avatar
Paul committed
172
// Float literal with lens {2, 2} and explicit strides {1, 2}, holding the
// buffer 1, 2, 3, 4.
inline migraphx::literal get_2x2_transposed()
{
    const migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
    return migraphx::literal{s, {1, 2, 3, 4}};
}

Paul's avatar
Paul committed
177
// One-dimensional float literal with lens {2} holding 1, 2.
inline migraphx::literal get_2()
{
    const migraphx::shape s{migraphx::shape::float_type, {2}};
    return migraphx::literal{s, {1, 2}};
}
181

Paul's avatar
Paul committed
182
// Float literal with lens {2, 1} and strides {1, 0} (zero stride on the last
// axis), holding the buffer 1, 2.
inline migraphx::literal get_2_broadcasted()
{
    const migraphx::shape s{migraphx::shape::float_type, {2, 1}, {1, 0}};
    return migraphx::literal{s, {1, 2}};
}