instruction.hpp 4.3 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP

#include <migraph/literal.hpp>
#include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
namespace migraph {
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
15
// Computes the shape produced by applying `op` to the result shapes of `args`.
// Forward-declared so that instruction's members can call it; the inline definition
// appears at the end of this header and must keep exactly this signature.
shape compute_shape(operation op, std::vector<instruction_ref> args);

Paul's avatar
Paul committed
16
17
struct instruction
{
Paul's avatar
Paul committed
18
19
    instruction() {}

Paul's avatar
Paul committed
20
    instruction(operation o, shape r, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
21
        : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
Paul's avatar
Paul committed
22
23
    {
    }
Paul's avatar
Paul committed
24

Paul's avatar
Paul committed
25
    instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
Paul's avatar
Paul committed
26

Paul's avatar
Paul committed
27
28
29
30
31
32
33
34
35
36
37
38
    void replace(operation o, shape r, std::vector<instruction_ref> args)
    {
        op = o;
        replace(std::move(r));
        replace(std::move(args));
    }

    void replace(shape r)
    {
        if(r != result)
        {
            result = r;
Paul's avatar
Paul committed
39
            for(auto&& ins : output)
Paul's avatar
Paul committed
40
            {
Paul's avatar
Paul committed
41
                assert(ins->op.name().front() != '@');
Paul's avatar
Paul committed
42
                ins->recompute_shape();
Paul's avatar
Paul committed
43
44
45
46
            }
        }
    }

Paul's avatar
Paul committed
47
    void recompute_shape() { replace(compute_shape(op, arguments)); }
Paul's avatar
Paul committed
48

Paul's avatar
Paul committed
49
50
51
52
53
54
    void replace(std::vector<instruction_ref> args)
    {
        clear_arguments();
        arguments = std::move(args);
    }

Paul's avatar
Paul committed
55
56
57
    void replace_argument(instruction_ref old, instruction_ref new_ins)
    {
        std::replace(arguments.begin(), arguments.end(), old, new_ins);
Paul's avatar
Paul committed
58
        old->remove_output(*this);
Paul's avatar
Paul committed
59
60
    }

Paul's avatar
Paul committed
61
62
    void clear_arguments()
    {
Paul's avatar
Paul committed
63
        for(auto&& arg : arguments)
Paul's avatar
Paul committed
64
        {
Paul's avatar
Paul committed
65
            arg->remove_output(*this);
Paul's avatar
Paul committed
66
        }
Paul's avatar
Paul committed
67
        arguments.clear();
Paul's avatar
Paul committed
68
69
70
71
72
73
74
    }

    friend bool operator==(const instruction& i, instruction_ref ref)
    {
        return std::addressof(i) == std::addressof(*ref);
    }

Paul's avatar
Paul committed
75
    bool valid(instruction_ref start) const
Paul's avatar
Paul committed
76
    {
Paul's avatar
Paul committed
77
        return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
Paul's avatar
Paul committed
78
79
80
81
82
83
84
                   auto self = std::find(i->output.begin(), i->output.end(), *this);
                   return self != i->output.end() &&
                          std::distance(start, i) < std::distance(start, *self);
               });
    }

    bool valid() const
Paul's avatar
Paul committed
85
    {
Paul's avatar
Paul committed
86
        shape computed;
Paul's avatar
Paul committed
87
        if(op.name() == "@literal")
Paul's avatar
Paul committed
88
        {
Paul's avatar
Paul committed
89
            computed = lit.get_shape();
Paul's avatar
Paul committed
90
        }
Paul's avatar
Paul committed
91
        else if(op.name() == "@param")
Paul's avatar
Paul committed
92
        {
Paul's avatar
Paul committed
93
94
            computed = result;
        }
Paul's avatar
Paul committed
95
96
        else
        {
Paul's avatar
Paul committed
97
98
99
100
101
102
103
104
            try
            {
                computed = compute_shape(op, arguments);
            }
            catch(migraph::exception&)
            {
                return false;
            }
Paul's avatar
Paul committed
105
        }
Paul's avatar
Paul committed
106
        return result == computed &&
Paul's avatar
Paul committed
107
108
109
110
               std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
                   return std::find(i->arguments.begin(), i->arguments.end(), *this) !=
                          i->arguments.end();
               });
Paul's avatar
Paul committed
111
112
    }

wsttiger's avatar
wsttiger committed
113
    shape get_shape() const { return result; }
114

Paul's avatar
Paul committed
115
    friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
Paul's avatar
Paul committed
116

Paul's avatar
Paul committed
117
    friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
Paul's avatar
Paul committed
118

Paul's avatar
Paul committed
119
    friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
Paul's avatar
Paul committed
120

Paul's avatar
Paul committed
121
122
123
124
125
126
    void add_output(instruction_ref ins)
    {
        if(std::find(output.begin(), output.end(), ins) == output.end())
            output.push_back(ins);
    }

Paul's avatar
Paul committed
127
    template <class T>
Paul's avatar
Paul committed
128
129
130
131
132
    void remove_output(const T& ins)
    {
        migraph::erase(output, ins);
    }

Paul's avatar
Paul committed
133
    operation op;
Paul's avatar
Paul committed
134
    shape result;
Paul's avatar
Paul committed
135
    std::vector<instruction_ref> output;
Paul's avatar
Paul committed
136
    std::vector<instruction_ref> arguments;
Paul's avatar
Paul committed
137
    literal lit;
Paul's avatar
Paul committed
138
139
};

Paul's avatar
Paul committed
140
141
142
/// Registers `ref` as a consumer (output) of each of its arguments.
/// Because instruction::add_output skips entries that are already present,
/// calling this repeatedly on the same instruction is harmless.
inline void backreference(instruction_ref ref)
{
    for(auto it = ref->arguments.begin(); it != ref->arguments.end(); ++it)
    {
        const auto& input = *it;
        input->add_output(ref);
    }
}

Paul's avatar
Paul committed
146
147
148
149
150
151
152
/// Graph-level argument substitution for `ins`: every `old` argument becomes
/// `new_ins`. The output back-edges of `ins`'s arguments are then restored, and the
/// shape of `ins` is recomputed.
inline void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
{
    instruction& target = *ins;
    // Rewrite the argument list; this also detaches `ins` from old's outputs.
    target.replace_argument(old, new_ins);
    // Re-register `ins` as an output of each of its (possibly new) arguments.
    backreference(ins);
    // Argument shapes may differ now, so refresh the result shape.
    target.recompute_shape();
}

Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
161
162
// TODO: Move to a cpp file
// TODO: Use const ref for vector
/// Gathers the result shape of every argument, in order, and asks `op` to compute
/// the shape it would produce for those inputs.
inline shape compute_shape(operation op, std::vector<instruction_ref> args)
{
    std::vector<shape> input_shapes;
    input_shapes.reserve(args.size());
    for(auto&& arg : args)
        input_shapes.push_back(arg->result);
    return op.compute_shape(input_shapes);
}

Paul's avatar
Paul committed
163
} // namespace migraph
Paul's avatar
Paul committed
164
165

#endif