// basic_ops.hpp
#include <migraphx/program.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>

struct sum_op
{
    std::string name() const { return "sum"; }
Paul's avatar
Paul committed
8
9
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
10
    {
Paul's avatar
Paul committed
11
        migraphx::argument result;
Paul's avatar
Paul committed
12
13
14
15
16
17
18
19
20
21
        if(args.size() != 2)
            MIGRAPH_THROW("Wrong args");
        if(args[0].get_shape() != args[1].get_shape())
            MIGRAPH_THROW("Wrong args");
        if(args[0].get_shape().lens().size() != 1)
            MIGRAPH_THROW("Wrong args");
        if(args[0].get_shape().lens().front() != 1)
            MIGRAPH_THROW("Wrong args");

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
22
            args[1].visit_at([&](auto y) { result = migraphx::literal{x + y}.get_argument(); });
Paul's avatar
Paul committed
23
24
25
26
        });
        return result;
    }

Paul's avatar
Paul committed
27
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
28
29
30
31
32
33
34
35
36
37
    {
        if(inputs.size() != 2)
            MIGRAPH_THROW("Wrong inputs");
        return inputs.front();
    }
};

struct minus_op
{
    std::string name() const { return "minus"; }
Paul's avatar
Paul committed
38
39
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
40
    {
Paul's avatar
Paul committed
41
        migraphx::argument result;
Paul's avatar
Paul committed
42
43
44
45
46
47
48
49
50
51
        if(args.size() != 2)
            MIGRAPH_THROW("Wrong args");
        if(args[0].get_shape() != args[1].get_shape())
            MIGRAPH_THROW("Wrong args");
        if(args[0].get_shape().lens().size() != 1)
            MIGRAPH_THROW("Wrong args");
        if(args[0].get_shape().lens().front() != 1)
            MIGRAPH_THROW("Wrong args");

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
52
            args[1].visit_at([&](auto y) { result = migraphx::literal{x - y}.get_argument(); });
Paul's avatar
Paul committed
53
54
55
56
        });
        return result;
    }

Paul's avatar
Paul committed
57
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
58
59
60
61
62
63
    {
        if(inputs.size() != 2)
            MIGRAPH_THROW("Wrong inputs");
        return inputs.front();
    }
};

struct pass_op
{
    std::string name() const { return "pass"; }
Paul's avatar
Paul committed
68
69
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
70
71
72
73
74
75
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
76
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
Paul's avatar
Paul committed
77
78
79
80
81
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
Paul's avatar
Paul committed
82
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
83
};

struct pass_standard_op
{
    std::string name() const { return "pass"; }
Paul's avatar
Paul committed
88
89
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
90
91
92
93
94
95
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
96
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
97
    {
Paul's avatar
Paul committed
98
        for(auto&& input : inputs)
99
100
101
102
103
104
105
106
        {
            if(not input.standard())
                throw std::runtime_error("Not standard shape");
        }
        if(inputs.empty())
            return {};
        return inputs.front();
    }
Paul's avatar
Paul committed
107
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
108
109
};

struct nop
{
    std::string name() const { return "nop"; }
Paul's avatar
Paul committed
113
114
115
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
116
117
118
119
    {
        return {};
    }

Paul's avatar
Paul committed
120
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
121
};

/// Float literal with lens {2, 2} (no explicit strides) holding 1, 2, 3, 4.
inline migraphx::literal get_2x2()
{
    const migraphx::shape s{migraphx::shape::float_type, {2, 2}};
    return migraphx::literal{s, {1, 2, 3, 4}};
}

/// Float literal with lens {2, 2} and strides {1, 2} holding 1, 2, 3, 4.
/// With these strides element (i, j) sits at offset i + 2*j, i.e. a transposed
/// (column-major) view of the same buffer.
inline migraphx::literal get_2x2_transposed()
{
    const migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
    return migraphx::literal{s, {1, 2, 3, 4}};
}

/// One-dimensional float literal with lens {2} holding 1, 2.
inline migraphx::literal get_2()
{
    const migraphx::shape s{migraphx::shape::float_type, {2}};
    return migraphx::literal{s, {1, 2}};
}

/// Float literal with lens {2, 1} and strides {1, 0} holding 1, 2.
/// The zero stride on the second axis is what makes this a "broadcasted" shape.
inline migraphx::literal get_2_broadcasted()
{
    const migraphx::shape s{migraphx::shape::float_type, {2, 1}, {1, 0}};
    return migraphx::literal{s, {1, 2}};
}