"src/targets/vscode:/vscode.git/clone" did not exist on "75f5ed4ac1bcb5994f0d5e3e2d34790791e3d6a0"
instruction.hpp 5.3 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP

#include <migraph/literal.hpp>
#include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

Paul's avatar
Paul committed
13
namespace migraph {
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
// Computes the result shape of `op` when applied to the shapes of `args`.
// Declared ahead of `instruction` because its member functions
// (recompute_shape, valid) call it; the inline definition follows the struct.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
18
struct instruction
{
Paul's avatar
Paul committed
19
20
    instruction() {}

Paul's avatar
Paul committed
21
    instruction(operation o, shape r, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
22
        : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
Paul's avatar
Paul committed
23
24
    {
    }
Paul's avatar
Paul committed
25

Paul's avatar
Paul committed
26
    instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
Paul's avatar
Paul committed
27

Paul's avatar
Paul committed
28
    void replace(const shape& r)
Paul's avatar
Paul committed
29
30
31
32
    {
        if(r != result)
        {
            result = r;
Paul's avatar
Paul committed
33
            for(auto&& ins : output)
Paul's avatar
Paul committed
34
            {
Paul's avatar
Paul committed
35
                assert(ins->name().front() != '@');
Paul's avatar
Paul committed
36
                ins->recompute_shape();
Paul's avatar
Paul committed
37
38
39
40
            }
        }
    }

Paul's avatar
Paul committed
41
    void recompute_shape() { replace(compute_shape(op, arguments)); }
Paul's avatar
Paul committed
42

Paul's avatar
Paul committed
43
44
    void clear_arguments()
    {
Paul's avatar
Paul committed
45
        for(auto&& arg : arguments)
Paul's avatar
Paul committed
46
        {
Paul's avatar
Paul committed
47
            arg->remove_output(*this);
Paul's avatar
Paul committed
48
        }
Paul's avatar
Paul committed
49
        arguments.clear();
Paul's avatar
Paul committed
50
51
52
53
54
55
56
    }

    friend bool operator==(const instruction& i, instruction_ref ref)
    {
        return std::addressof(i) == std::addressof(*ref);
    }

Paul's avatar
Paul committed
57
    bool valid(instruction_ref start) const
Paul's avatar
Paul committed
58
    {
Paul's avatar
Paul committed
59
        return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
Paul's avatar
Paul committed
60
61
                   auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
                   return self != i->outputs().end() &&
Paul's avatar
Paul committed
62
63
64
65
66
                          std::distance(start, i) < std::distance(start, *self);
               });
    }

    bool valid() const
Paul's avatar
Paul committed
67
    {
Paul's avatar
Paul committed
68
        shape computed;
Paul's avatar
Paul committed
69
        if(op.name() == "@literal")
Paul's avatar
Paul committed
70
        {
Paul's avatar
Paul committed
71
            computed = lit.get_shape();
Paul's avatar
Paul committed
72
        }
Paul's avatar
Paul committed
73
        else if(op.name() == "@param")
Paul's avatar
Paul committed
74
        {
Paul's avatar
Paul committed
75
76
            computed = result;
        }
Paul's avatar
Paul committed
77
78
        else
        {
Paul's avatar
Paul committed
79
80
81
82
83
84
85
86
            try
            {
                computed = compute_shape(op, arguments);
            }
            catch(migraph::exception&)
            {
                return false;
            }
Paul's avatar
Paul committed
87
        }
Paul's avatar
Paul committed
88
        return result == computed &&
Paul's avatar
Paul committed
89
               std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
Paul's avatar
Paul committed
90
91
                   return std::find(i->inputs().begin(), i->inputs().end(), *this) !=
                          i->inputs().end();
Paul's avatar
Paul committed
92
               });
Paul's avatar
Paul committed
93
94
    }

wsttiger's avatar
wsttiger committed
95
    shape get_shape() const { return result; }
Paul's avatar
Paul committed
96
    const literal& get_literal() const
97
98
    {
        assert(op.name() == "@literal");
Paul's avatar
Paul committed
99
        return lit;
100
    }
101

Paul's avatar
Paul committed
102
    const operation& get_operator() const { return op; }
Paul's avatar
Paul committed
103

Paul's avatar
Paul committed
104
    std::string name() const { return op.name(); }
Paul's avatar
Paul committed
105

Paul's avatar
Paul committed
106
    const std::vector<instruction_ref>& inputs() const { return arguments; }
Paul's avatar
Paul committed
107

Paul's avatar
Paul committed
108
    const std::vector<instruction_ref>& outputs() const { return output; }
Paul's avatar
Paul committed
109

Paul's avatar
Paul committed
110
    friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
Paul's avatar
Paul committed
111

Paul's avatar
Paul committed
112
    friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
Paul's avatar
Paul committed
113

Paul's avatar
Paul committed
114
    friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
Paul's avatar
Paul committed
115

Paul's avatar
Paul committed
116
117
118
119
120
121
    void add_output(instruction_ref ins)
    {
        if(std::find(output.begin(), output.end(), ins) == output.end())
            output.push_back(ins);
    }

Paul's avatar
Paul committed
122
    template <class T>
Paul's avatar
Paul committed
123
124
125
126
127
    void remove_output(const T& ins)
    {
        migraph::erase(output, ins);
    }

Paul's avatar
Paul committed
128
    static void backreference(instruction_ref ref)
Paul's avatar
Paul committed
129
    {
Paul's avatar
Paul committed
130
131
        for(auto&& arg : ref->inputs())
            arg->add_output(ref);
Paul's avatar
Paul committed
132
133
    }

Paul's avatar
Paul committed
134
    static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
Paul's avatar
Paul committed
135
    {
Paul's avatar
Paul committed
136
137
138
        ins->replace_argument(old, new_ins);
        backreference(ins);
        ins->recompute_shape();
Paul's avatar
Paul committed
139
    }
Paul's avatar
Paul committed
140

141
142
143
144
145
146
    static void replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args)
    {
        ins->replace(o, r, std::move(args));
        backreference(ins);
    }
private:
Paul's avatar
Paul committed
147
148
149
150
151
152
153
154
    // internal
    void replace(operation o, const shape& r, std::vector<instruction_ref> args)
    {
        op = std::move(o);
        replace(r);
        replace(std::move(args));
    }

Paul's avatar
Paul committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    // internal
    void replace(std::vector<instruction_ref> args)
    {
        clear_arguments();
        arguments = std::move(args);
    }

    // internal
    void replace_argument(instruction_ref old, instruction_ref new_ins)
    {
        std::replace(arguments.begin(), arguments.end(), old, new_ins);
        old->remove_output(*this);
    }

Paul's avatar
Paul committed
169
    operation op;
Paul's avatar
Paul committed
170
    shape result;
Paul's avatar
Paul committed
171
    std::vector<instruction_ref> output;
Paul's avatar
Paul committed
172
    std::vector<instruction_ref> arguments;
Paul's avatar
Paul committed
173
    literal lit;
Paul's avatar
Paul committed
174
175
};

Paul's avatar
Paul committed
176
// TODO: Move to a cpp file
Paul's avatar
Paul committed
177
/// Derives the result shape of `op` from the shapes of `args`.
/// The argument shapes are collected in argument order.
inline shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    std::vector<shape> arg_shapes;
    arg_shapes.reserve(args.size());
    for(const auto& arg : args)
        arg_shapes.push_back(arg->get_shape());
    return op.compute_shape(arg_shapes);
}

Paul's avatar
Paul committed
185
} // namespace migraph
Paul's avatar
Paul committed
186

Paul's avatar
Paul committed
187
188
189
namespace std {
/// Hash for `migraph::instruction_ref`.
/// It hashes the address of the referenced instruction, which matches the
/// identity-based equality used for instructions.
template <>
struct hash<migraph::instruction_ref>
{
    using argument_type = migraph::instruction_ref;
    using result_type   = std::size_t;

    result_type operator()(const argument_type& ref) const noexcept
    {
        auto* target = &*ref;
        return std::hash<migraph::instruction*>{}(target);
    }
};
} // namespace std

Paul's avatar
Paul committed
200
#endif