basic_ops.hpp 5.15 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/program.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>
Paul's avatar
Paul committed
4
5
6
7

struct sum_op
{
    std::string name() const { return "sum"; }
Paul's avatar
Paul committed
8
9
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
10
    {
Paul's avatar
Paul committed
11
        migraphx::argument result;
Paul's avatar
Paul committed
12
        if(args.size() != 2)
Paul's avatar
Paul committed
13
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
14
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
15
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
16
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
17
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
18
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
19
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
20
21

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
22
            args[1].visit_at([&](auto y) { result = migraphx::literal{x + y}.get_argument(); });
Paul's avatar
Paul committed
23
24
25
26
        });
        return result;
    }

Paul's avatar
Paul committed
27
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
28
29
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
30
            MIGRAPHX_THROW("Wrong inputs");
Paul's avatar
Paul committed
31
32
33
34
35
36
37
        return inputs.front();
    }
};

struct minus_op
{
    std::string name() const { return "minus"; }
Paul's avatar
Paul committed
38
39
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
40
    {
Paul's avatar
Paul committed
41
        migraphx::argument result;
Paul's avatar
Paul committed
42
        if(args.size() != 2)
Paul's avatar
Paul committed
43
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
44
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
45
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
46
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
47
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
48
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
49
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
50
51

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
52
            args[1].visit_at([&](auto y) { result = migraphx::literal{x - y}.get_argument(); });
Paul's avatar
Paul committed
53
54
55
56
        });
        return result;
    }

Paul's avatar
Paul committed
57
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
58
59
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
60
            MIGRAPHX_THROW("Wrong inputs");
Paul's avatar
Paul committed
61
62
63
        return inputs.front();
    }
};
Paul's avatar
Paul committed
64
65
66
67

struct pass_op
{
    std::string name() const { return "pass"; }
Paul's avatar
Paul committed
68
69
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
70
71
72
73
74
75
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
76
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
77
78
79
80
81
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
Paul Fultz II's avatar
Paul Fultz II committed
82
    int output_alias(const std::vector<migraphx::shape>& s) const { return s.empty() ? -1 : 0; }
Shucai Xiao's avatar
Shucai Xiao committed
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
};

struct mod_pass_op
{
    std::string name() const { return "mod_pass"; }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs,
                                  std::vector<migraphx::module_ref> mods) const
    {
        if(!mods.empty())
        {
            auto out_shapes = mods[0]->get_output_shapes();
            return out_shapes[0];
        }
        if(!inputs.empty())
        {
            return inputs.front();
        }

        return {};
    }

Paul's avatar
Paul committed
105
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
106
};
107

Paul's avatar
Paul committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/// Test-only identity operator "unary_pass" that requires exactly one input.
///
/// compute() forwards the first argument (empty argument if none were given);
/// compute_shape() throws "Wrong inputs" unless there is exactly one input
/// shape, which it returns. The output always aliases input 0.
struct unary_pass_op
{
    std::string name() const { return "unary_pass"; }

    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        return args.empty() ? migraphx::argument{} : args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.size() == 1)
            return inputs.front();
        MIGRAPHX_THROW("Wrong inputs");
    }

    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

128
129
130
struct pass_standard_op
{
    std::string name() const { return "pass"; }
Paul's avatar
Paul committed
131
132
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
133
134
135
136
137
138
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
139
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
140
    {
Paul's avatar
Paul committed
141
        for(auto&& input : inputs)
142
143
144
145
146
147
148
149
        {
            if(not input.standard())
                throw std::runtime_error("Not standard shape");
        }
        if(inputs.empty())
            return {};
        return inputs.front();
    }
Paul's avatar
Paul committed
150
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
151
152
};

153
154
155
struct nop
{
    std::string name() const { return "nop"; }
Paul's avatar
Paul committed
156
157
158
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
159
160
161
162
    {
        return {};
    }

Paul's avatar
Paul committed
163
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
164
};
165

Paul's avatar
Paul committed
166
/// Returns a 2x2 float literal with default strides, holding 1, 2, 3, 4.
inline migraphx::literal get_2x2()
{
    const migraphx::shape mat_shape{migraphx::shape::float_type, {2, 2}};
    return migraphx::literal{mat_shape, {1, 2, 3, 4}};
}

Paul's avatar
Paul committed
171
/// Returns a 2x2 float literal over the data 1, 2, 3, 4, viewed through explicit
/// strides {1, 2}. The first axis is therefore the fastest-moving one, i.e. a
/// transposed view of the same buffer.
inline migraphx::literal get_2x2_transposed()
{
    const migraphx::shape mat_shape{migraphx::shape::float_type, {2, 2}, {1, 2}};
    return migraphx::literal{mat_shape, {1, 2, 3, 4}};
}

Paul's avatar
Paul committed
176
/// Returns a rank-1 float literal of length 2 holding 1, 2.
inline migraphx::literal get_2()
{
    const migraphx::shape vec_shape{migraphx::shape::float_type, {2}};
    return migraphx::literal{vec_shape, {1, 2}};
}
180

Paul's avatar
Paul committed
181
/// Returns a float literal with lens {2, 1} and strides {1, 0} over the data
/// 1, 2. The zero stride on the size-1 trailing axis marks it as broadcast.
inline migraphx::literal get_2_broadcasted()
{
    const migraphx::shape bcast_shape{migraphx::shape::float_type, {2, 1}, {1, 0}};
    return migraphx::literal{bcast_shape, {1, 2}};
}