instruction.cpp 4.6 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
#include <migraph/instruction.hpp>
#include <migraph/builtin.hpp>
#include <migraph/erase.hpp>

namespace migraph {

Paul's avatar
Paul committed
7
8
9
10
// Build an instruction from an operator, its already-computed result shape, and the
// instructions it consumes. Every argument is taken by value and moved into its member.
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
    : op(std::move(o)),
      result(std::move(r)),
      arguments(std::move(args))
{
}
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
13
14
15
// Construct a literal instruction: the operator is builtin::literal and the result
// shape is taken from the literal's own shape. No arguments are set.
// NOTE(review): `result(l.get_shape())` only reads `l` before `lit(std::move(l))` moves
// from it if `result` is declared before `lit` in the class (members are initialized in
// declaration order, not initializer-list order) — confirm in instruction.hpp.
instruction::instruction(literal l)
    : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
}
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
18
19
void instruction::replace(const shape& r)
{
    if(r != result)
Paul's avatar
Paul committed
20
    {
Paul's avatar
Paul committed
21
22
        result = r;
        for(auto&& ins : output)
Paul's avatar
Paul committed
23
        {
Paul's avatar
Paul committed
24
25
            assert(ins->name().front() != '@');
            ins->recompute_shape();
Paul's avatar
Paul committed
26
27
        }
    }
Paul's avatar
Paul committed
28
}
Paul's avatar
Paul committed
29

Paul's avatar
Paul committed
30
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
Paul's avatar
Paul committed
31

Paul's avatar
Paul committed
32
33
34
void instruction::clear_arguments()
{
    for(auto&& arg : arguments)
Paul's avatar
Paul committed
35
    {
Paul's avatar
Paul committed
36
        arg->remove_output(*this);
Paul's avatar
Paul committed
37
    }
Paul's avatar
Paul committed
38
39
40
41
42
43
44
    arguments.clear();
}

// Identity comparison: true only when `ref` refers to the very object `i`.
bool operator==(const instruction& i, instruction_ref ref)
{
    const instruction* target = std::addressof(*ref);
    return target == std::addressof(i);
}
Paul's avatar
Paul committed
45

Paul's avatar
Paul committed
46
47
48
49
50
51
52
53
54
55
56
57
58
// Check local validity (see valid()) plus the def-use links to the arguments: each
// argument must list this instruction among its outputs, and must be reached from
// `start` in fewer steps than this instruction is.
bool instruction::valid(instruction_ref start) const
{
    if(not valid())
        return false;

    return std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref arg) {
        const auto& consumers = arg->outputs();
        auto self             = std::find(consumers.begin(), consumers.end(), *this);
        if(self == consumers.end())
            return false;
        return std::distance(start, arg) < std::distance(start, *self);
    });
}

bool instruction::valid() const
{
    shape computed;
    if(op.name() == "@literal")
Paul's avatar
Paul committed
59
    {
Paul's avatar
Paul committed
60
        computed = lit.get_shape();
Paul's avatar
Paul committed
61
    }
Paul's avatar
Paul committed
62
    else if(op.name() == "@param")
Paul's avatar
Paul committed
63
    {
Paul's avatar
Paul committed
64
        computed = result;
Paul's avatar
Paul committed
65
    }
Paul's avatar
Paul committed
66
    else
Paul's avatar
Paul committed
67
    {
Paul's avatar
Paul committed
68
        try
Paul's avatar
Paul committed
69
        {
Paul's avatar
Paul committed
70
            computed = compute_shape(op, arguments);
Paul's avatar
Paul committed
71
        }
Paul's avatar
Paul committed
72
        catch(migraph::exception&)
Paul's avatar
Paul committed
73
        {
Paul's avatar
Paul committed
74
            return false;
Paul's avatar
Paul committed
75
76
        }
    }
Paul's avatar
Paul committed
77
78
79
80
    return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
               return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
           });
}
Paul's avatar
Paul committed
81

Paul's avatar
Paul committed
82
83
84
85
86
87
// Return (a copy of) the cached result shape.
shape instruction::get_shape() const
{
    return result;
}
// Access the stored literal. Only meaningful for "@literal" instructions; asserted in
// debug builds.
const literal& instruction::get_literal() const
{
    assert(name() == "@literal");
    return lit;
}
Paul's avatar
Paul committed
88

Paul's avatar
Paul committed
89
// Expose the stored operator by const reference (no copy).
const operation& instruction::get_operator() const
{
    return op;
}
Paul's avatar
Paul committed
90

Paul's avatar
Paul committed
91
// An instruction is named after its operator.
std::string instruction::name() const
{
    return op.name();
}
Paul's avatar
Paul committed
92

Paul's avatar
Paul committed
93
// The instructions whose results this instruction consumes.
const std::vector<instruction_ref>& instruction::inputs() const
{
    return arguments;
}
Paul's avatar
Paul committed
94

Paul's avatar
Paul committed
95
// The instructions that consume this instruction's result.
const std::vector<instruction_ref>& instruction::outputs() const
{
    return output;
}
Paul's avatar
Paul committed
96

Paul's avatar
Paul committed
97
98
// Structural equality: same result shape, operator, and argument list; literal
// instructions must additionally hold equal literal data.
bool operator==(const instruction& x, const instruction& y)
{
    const bool same_core = x.result == y.result and x.op == y.op and x.arguments == y.arguments;
    if(not same_core)
        return false;
    return x.name() != "@literal" or x.lit == y.lit;
}
Paul's avatar
Paul committed
105

Paul's avatar
Paul committed
106
107
// Negation of structural equality.
bool operator!=(const instruction& x, const instruction& y)
{
    return not(x == y);
}

Paul's avatar
Paul committed
108
// Mirror of the (instruction, instruction_ref) identity comparison.
bool operator==(instruction_ref ref, const instruction& i)
{
    return i == ref;
}
Paul's avatar
Paul committed
109

Paul's avatar
Paul committed
110
// True when `ref` does not refer to `i` itself.
bool operator!=(const instruction& i, instruction_ref ref)
{
    return not(i == ref);
}
Paul's avatar
Paul committed
111

Paul's avatar
Paul committed
112
// Mirror of operator!=(const instruction&, instruction_ref).
bool operator!=(instruction_ref ref, const instruction& i)
{
    return not(i == ref);
}
Paul's avatar
Paul committed
113

Paul's avatar
Paul committed
114
115
116
117
118
// Record `ins` as a consumer of this instruction, keeping the output list duplicate-free.
void instruction::add_output(instruction_ref ins)
{
    const bool already_present = std::find(output.begin(), output.end(), ins) != output.end();
    if(not already_present)
        output.push_back(ins);
}
Paul's avatar
Paul committed
119

Paul's avatar
Paul committed
120
121
122
123
124
// Remove entries matching `ins` from the output list. `T` may be an instruction_ref or
// an instruction (compared by identity via the mixed operator== overloads); the actual
// removal is delegated to migraph::erase.
template <class T>
void instruction::remove_output(const T& ins)
{
    auto& consumers = output;
    migraph::erase(consumers, ins);
}
Paul's avatar
Paul committed
125

Paul's avatar
Paul committed
126
127
128
129
130
// Register `ref` as an output of each of its inputs, so the argument -> consumer links
// mirror the consumer -> argument links.
void instruction::backreference(instruction_ref ref)
{
    for(const auto& producer : ref->inputs())
        producer->add_output(ref);
}
Paul's avatar
Paul committed
131

Paul's avatar
Paul committed
132
133
134
135
136
137
138
139
// Replace `old` with `new_ins` in the arguments of `ins`, then repair the graph:
// re-register `ins` as an output of each of its inputs and recompute its shape
// (propagating any change to its outputs).
void instruction::replace_argument(instruction_ref ins,
                                   instruction_ref old,
                                   instruction_ref new_ins)
{
    ins->replace_argument(old, new_ins);
    // The member overload only unlinks `old`; this adds `ins` to new_ins' output list.
    backreference(ins);
    ins->recompute_shape();
}
Paul's avatar
Paul committed
140

Paul's avatar
Paul committed
141
142
143
144
145
146
147
148
// Replace the operator, result shape, and arguments of `ins`, then register `ins` as an
// output of each of its new arguments.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args)
{
    ins->replace(std::move(o), r, std::move(args));
    // The member replace detaches from the old arguments but does not link to the new ones.
    backreference(ins);
}
Paul's avatar
Paul committed
149

Paul's avatar
Paul committed
150
151
152
153
154
155
// Overwrite the operator, set the result shape (propagating to outputs if it changed),
// and swap in the new argument list. This does not register this instruction as an
// output of the new arguments; the static replace overload does that via backreference.
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    op = std::move(o);
    replace(r);
    replace(std::move(args));
}
Paul's avatar
Paul committed
156

Paul's avatar
Paul committed
157
158
159
160
161
// Detach from the current arguments (removing this instruction from their output lists)
// and adopt `args` as the new argument list. The output lists of the new arguments are
// not updated here.
void instruction::replace(std::vector<instruction_ref> args)
{
    clear_arguments();
    arguments = std::move(args);
}
Paul's avatar
Paul committed
162

Paul's avatar
Paul committed
163
164
165
166
167
// Replace every occurrence of `old` in this instruction's arguments with `new_ins`, and
// remove this instruction from `old`'s output list. Registering this instruction as an
// output of `new_ins` is left to the caller (the static replace_argument overload calls
// backreference for that).
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
    // Replacing an argument with itself must be a no-op. Without this guard,
    // remove_output below would unlink this instruction from `old` even though `old` is
    // still an argument, breaking the def-use symmetry that valid(start) checks.
    if(old == new_ins)
        return;
    std::replace(arguments.begin(), arguments.end(), old, new_ins);
    old->remove_output(*this);
}
Paul's avatar
Paul committed
168
169
170
171
172
173
174
175
176
177

// Gather the shapes of `args` (in order) and let `op` derive its output shape from them.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    std::vector<shape> input_shapes;
    input_shapes.reserve(args.size());
    for(auto&& arg : args)
        input_shapes.push_back(arg->get_shape());
    return op.compute_shape(input_shapes);
}

} // namespace migraph