basic_ops.hpp 4.58 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/program.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>
Paul's avatar
Paul committed
4
5
6
7

struct sum_op
{
    std::string name() const { return "sum"; }
Paul's avatar
Paul committed
8
9
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
10
    {
Paul's avatar
Paul committed
11
        migraphx::argument result;
Paul's avatar
Paul committed
12
        if(args.size() != 2)
Paul's avatar
Paul committed
13
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
14
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
15
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
16
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
17
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
18
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
19
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
20
21

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
22
            args[1].visit_at([&](auto y) { result = migraphx::literal{x + y}.get_argument(); });
Paul's avatar
Paul committed
23
24
25
26
        });
        return result;
    }

Paul's avatar
Paul committed
27
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
28
29
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
30
            MIGRAPHX_THROW("Wrong inputs");
Paul's avatar
Paul committed
31
32
33
34
35
36
37
        return inputs.front();
    }
};

struct minus_op
{
    std::string name() const { return "minus"; }
Paul's avatar
Paul committed
38
39
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
40
    {
Paul's avatar
Paul committed
41
        migraphx::argument result;
Paul's avatar
Paul committed
42
        if(args.size() != 2)
Paul's avatar
Paul committed
43
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
44
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
45
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
46
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
47
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
48
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
49
            MIGRAPHX_THROW("Wrong args");
Paul's avatar
Paul committed
50
51

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
52
            args[1].visit_at([&](auto y) { result = migraphx::literal{x - y}.get_argument(); });
Paul's avatar
Paul committed
53
54
55
56
        });
        return result;
    }

Paul's avatar
Paul committed
57
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
58
59
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
60
            MIGRAPHX_THROW("Wrong inputs");
Paul's avatar
Paul committed
61
62
63
        return inputs.front();
    }
};
Paul's avatar
Paul committed
64
65
66
67

struct pass_op
{
    std::string name() const { return "pass"; }
Paul's avatar
Paul committed
68
69
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
70
71
72
73
74
75
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
76
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
77
78
79
80
81
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
Paul's avatar
Paul committed
82
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
83
};
84

Paul's avatar
Paul committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/// Test operator "unary_pass": like a pass-through, but shape inference
/// insists on exactly one input. The output aliases input 0.
struct unary_pass_op
{
    std::string name() const { return "unary_pass"; }

    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        // Forward the first argument when there is one; otherwise an empty argument.
        if(not args.empty())
            return args.front();
        return {};
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        const bool exactly_one = inputs.size() == 1;
        if(not exactly_one)
            MIGRAPHX_THROW("Wrong inputs");
        return inputs.front();
    }

    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

105
106
107
struct pass_standard_op
{
    std::string name() const { return "pass"; }
Paul's avatar
Paul committed
108
109
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
110
111
112
113
114
115
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
116
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
117
    {
Paul's avatar
Paul committed
118
        for(auto&& input : inputs)
119
120
121
122
123
124
125
126
        {
            if(not input.standard())
                throw std::runtime_error("Not standard shape");
        }
        if(inputs.empty())
            return {};
        return inputs.front();
    }
Paul's avatar
Paul committed
127
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
128
129
};

130
131
132
struct nop
{
    std::string name() const { return "nop"; }
Paul's avatar
Paul committed
133
134
135
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
136
137
138
139
    {
        return {};
    }

Paul's avatar
Paul committed
140
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
141
};
142

Paul's avatar
Paul committed
143
/// 2x2 float literal over the values 1, 2, 3, 4, using the shape's default
/// strides (no explicit strides given).
inline migraphx::literal get_2x2()
{
    const migraphx::shape square{migraphx::shape::float_type, {2, 2}};
    return migraphx::literal{square, {1, 2, 3, 4}};
}

Paul's avatar
Paul committed
148
/// 2x2 float literal over storage 1, 2, 3, 4 with explicit strides {1, 2}.
/// Logical element (i, j) is read from storage offset i * 1 + j * 2.
inline migraphx::literal get_2x2_transposed()
{
    const migraphx::shape transposed{migraphx::shape::float_type, {2, 2}, {1, 2}};
    return migraphx::literal{transposed, {1, 2, 3, 4}};
}

Paul's avatar
Paul committed
153
/// One-dimensional float literal of length 2 holding 1, 2.
inline migraphx::literal get_2()
{
    const migraphx::shape pair_shape{migraphx::shape::float_type, {2}};
    return migraphx::literal{pair_shape, {1, 2}};
}
157

Paul's avatar
Paul committed
158
/// Float literal with lens {2, 1} and strides {1, 0} over storage 1, 2.
/// The zero stride sits on the size-1 trailing dimension.
inline migraphx::literal get_2_broadcasted()
{
    const migraphx::shape column{migraphx::shape::float_type, {2, 1}, {1, 0}};
    return migraphx::literal{column, {1, 2}};
}