instruction.cpp 6.3 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
Paul's avatar
Paul committed
4

Paul's avatar
Paul committed
5
namespace migraphx {
Paul's avatar
Paul committed
6
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
9
10
11
// Build a regular instruction from an operator, its already-computed output
// shape `r`, and its input iterators. All three are moved into the members;
// no back-references (this instruction in each input's output list) are
// created here — the body is empty.
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
    : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
14
15
16
// Build a "@literal" instruction: the operator is builtin::literal and the
// output shape is taken from the literal itself, which is then stored in `lit`.
// NOTE(review): `result(l.get_shape())` must be initialized before
// `lit(std::move(l))` moves from `l`; that only holds if `result` is declared
// before `lit` in the class definition — confirm against instruction.hpp.
instruction::instruction(literal l)
    : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
}
Paul's avatar
Paul committed
17

Paul's avatar
Paul committed
18
19
20
void instruction::replace(const shape& r)
{
    if(r != result)
Paul's avatar
Paul committed
21
    {
Paul's avatar
Paul committed
22
23
        result = r;
        for(auto&& ins : output)
Paul's avatar
Paul committed
24
        {
25
26
27
            if(ins->name() == "@return")
                continue;

Paul's avatar
Paul committed
28
29
            assert(ins->name().front() != '@');
            ins->recompute_shape();
Paul's avatar
Paul committed
30
31
        }
    }
Paul's avatar
Paul committed
32
}
Paul's avatar
Paul committed
33

Paul's avatar
Paul committed
34
// Swap in a new operator and re-derive this instruction's shape from it; the
// shape update is propagated to consumers by replace(const shape&).
void instruction::replace(operation o)
{
    this->op = std::move(o);
    this->recompute_shape();
}

Paul's avatar
Paul committed
40
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
Paul's avatar
Paul committed
41

Paul's avatar
Paul committed
42
43
44
void instruction::clear_arguments()
{
    for(auto&& arg : arguments)
Paul's avatar
Paul committed
45
    {
Paul's avatar
Paul committed
46
        arg->remove_output(*this);
Paul's avatar
Paul committed
47
    }
Paul's avatar
Paul committed
48
49
50
51
52
53
54
    arguments.clear();
}

// Identity comparison: true when `ref` points at exactly the object `i`
// (address equality, not structural equality).
bool operator==(const instruction& i, instruction_ref ref)
{
    const instruction* lhs = std::addressof(i);
    const instruction* rhs = std::addressof(*ref);
    return lhs == rhs;
}
Paul's avatar
Paul committed
55

Paul's avatar
Paul committed
56
57
58
59
60
61
62
63
64
65
66
67
68
// Local validity (see valid()) plus an edge/order check relative to `start`:
// every input must list this instruction among its outputs, and the input must
// be closer to `start` than this instruction is.
// NOTE(review): std::distance on these iterators assumes both are reachable
// by walking forward from `start`.
bool instruction::valid(instruction_ref start) const
{
    if(not valid())
        return false;
    return std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref input) {
        const auto& consumers = input->outputs();
        const auto self_it    = std::find(consumers.begin(), consumers.end(), *this);
        if(self_it == consumers.end())
            return false;
        return std::distance(start, input) < std::distance(start, *self_it);
    });
}

bool instruction::valid() const
{
    shape computed;
    if(op.name() == "@literal")
Paul's avatar
Paul committed
69
    {
Paul's avatar
Paul committed
70
        computed = lit.get_shape();
Paul's avatar
Paul committed
71
    }
Paul's avatar
Paul committed
72
    else if(op.name() == "@param")
Paul's avatar
Paul committed
73
    {
Paul's avatar
Paul committed
74
        computed = result;
Paul's avatar
Paul committed
75
    }
76
77
78
79
    else if(op.name() == "@return")
    {
        computed = {};
    }
Paul's avatar
Paul committed
80
    else
Paul's avatar
Paul committed
81
    {
Paul's avatar
Paul committed
82
        try
Paul's avatar
Paul committed
83
        {
Paul's avatar
Paul committed
84
            computed = compute_shape(op, arguments);
Paul's avatar
Paul committed
85
        }
Paul's avatar
Paul committed
86
        catch(migraphx::exception&)
Paul's avatar
Paul committed
87
        {
Paul's avatar
Paul committed
88
            return false;
Paul's avatar
Paul committed
89
90
        }
    }
91

Paul's avatar
Paul committed
92
93
94
95
    return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
               return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
           });
}
Paul's avatar
Paul committed
96

Paul's avatar
Paul committed
97
98
99
100
101
102
// Return a copy of this instruction's output shape.
shape instruction::get_shape() const
{
    return this->result;
}
// Access the stored literal payload. Only valid for "@literal" instructions;
// this precondition is checked by assert in debug builds.
const literal& instruction::get_literal() const
{
    assert(op.name() == "@literal");
    return this->lit;
}
Paul's avatar
Paul committed
103

Paul's avatar
Paul committed
104
// Read-only access to the operator this instruction applies.
const operation& instruction::get_operator() const
{
    return this->op;
}
Paul's avatar
Paul committed
105

Paul's avatar
Paul committed
106
std::string instruction::name() const { return op.name(); }
Paul's avatar
Paul committed
107

Paul's avatar
Paul committed
108
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
Paul's avatar
Paul committed
109

Paul's avatar
Paul committed
110
const std::vector<instruction_ref>& instruction::outputs() const { return output; }
Paul's avatar
Paul committed
111

Paul's avatar
Paul committed
112
113
// Structural equality: same output shape, same operator and same input
// iterators; for "@literal" instructions the literal payloads must also match.
// Only x's name is checked, presumably because equal operators imply equal
// names.
bool operator==(const instruction& x, const instruction& y)
{
    const bool same_signature =
        std::tie(x.result, x.op, x.arguments) == std::tie(y.result, y.op, y.arguments);
    if(not same_signature)
        return false;
    return x.name() != "@literal" or x.lit == y.lit;
}
Paul's avatar
Paul committed
120

Paul's avatar
Paul committed
121
122
// Negation of structural equality.
bool operator!=(const instruction& x, const instruction& y)
{
    const bool equal = (x == y);
    return not equal;
}

Paul's avatar
Paul committed
123
// Identity comparison with the iterator on the left; delegates to the
// (instruction, iterator) overload.
bool operator==(instruction_ref ref, const instruction& i)
{
    return i == ref;
}
Paul's avatar
Paul committed
124

Paul's avatar
Paul committed
125
// True when `ref` does not point at the object `i`.
bool operator!=(const instruction& i, instruction_ref ref)
{
    return not(i == ref);
}
Paul's avatar
Paul committed
126

Paul's avatar
Paul committed
127
// True when `ref` does not point at the object `i` (iterator on the left).
bool operator!=(instruction_ref ref, const instruction& i)
{
    return not(ref == i);
}
Paul's avatar
Paul committed
128

Paul's avatar
Paul committed
129
130
131
132
133
// Record `ins` as a consumer of this instruction; already-recorded consumers
// are not duplicated.
void instruction::add_output(instruction_ref ins)
{
    const bool known = std::find(output.begin(), output.end(), ins) != output.end();
    if(known)
        return;
    output.push_back(ins);
}
Paul's avatar
Paul committed
134

Paul's avatar
Paul committed
135
136
137
138
139
// Make input edges bidirectional: register `ref` as an output of each of its
// inputs (duplicates are ignored by add_output).
void instruction::backreference(instruction_ref ref)
{
    const auto& ref_inputs = ref->inputs();
    std::for_each(ref_inputs.begin(), ref_inputs.end(), [&](instruction_ref input) {
        input->add_output(ref);
    });
}
Paul's avatar
Paul committed
140

Paul's avatar
Paul committed
141
142
143
144
145
146
147
148
// Rewire `ins` so every use of `old` becomes `new_ins`. The steps are:
// 1. Update the argument list, which also unregisters `ins` from old's outputs.
// 2. Re-link `ins` into its inputs' output lists.
// 3. Recompute the shape of `ins`.
void instruction::replace_argument(instruction_ref ins,
                                   instruction_ref old,
                                   instruction_ref new_ins)
{
    instruction& target = *ins;
    target.replace_argument(old, new_ins);
    backreference(ins);
    target.recompute_shape();
}
Paul's avatar
Paul committed
149

Paul's avatar
Paul committed
150
151
152
153
154
155
156
157
// Replace the operator, output shape and inputs of `ins` in one step, then
// register `ins` in the output lists of its new inputs.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args)
{
    instruction& target = *ins;
    target.replace(std::move(o), r, std::move(args));
    backreference(ins);
}
Paul's avatar
Paul committed
158

Paul's avatar
Paul committed
159
160
161
162
163
164
// Install a new operator, apply the new shape (propagating to consumers while
// the old input edges are still attached), then swap in the new inputs.
// The new inputs' output lists are not updated here.
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    this->op = std::move(o);
    this->replace(r);
    this->replace(std::move(args));
}
Paul's avatar
Paul committed
165

Paul's avatar
Paul committed
166
167
168
169
170
void instruction::replace(std::vector<instruction_ref> args)
{
    clear_arguments();
    arguments = std::move(args);
}
Paul's avatar
Paul committed
171

Paul's avatar
Paul committed
172
173
// Substitute `new_ins` for every occurrence of `old` among this instruction's
// inputs, then unregister this instruction from old's output list. `old` must
// currently be an input; this is asserted in debug builds. No back-reference
// to `new_ins` is added here.
// NOTE(review): if old == new_ins, this instruction is removed from old's
// outputs while still using it; callers appear to re-link via backreference.
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
    assert(std::find(arguments.begin(), arguments.end(), old) != arguments.end());
    for(auto& arg : arguments)
    {
        if(arg == old)
            arg = new_ins;
    }
    old->remove_output(*this);
}
Paul's avatar
Paul committed
178

Paul's avatar
Paul committed
179
180
181
182
183
184
bool instruction::can_eval() const
{
    if(op.name() == "@literal")
    {
        return true;
    }
Paul's avatar
Paul committed
185
    else if(is_context_free(op))
Paul's avatar
Paul committed
186
    {
Paul's avatar
Paul committed
187
188
        return std::all_of(
            this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
Paul's avatar
Paul committed
189
190
191
192
193
194
195
    }
    else
    {
        return false;
    }
}

Paul's avatar
Paul committed
196
// Evaluate this instruction without a context.
// - Literals return their stored argument.
// - Context-free operators evaluate their inputs recursively (with
//   check_eval = false) and call op.compute on the results.
// - Anything else yields an empty argument.
// When `check_eval` is true, an empty argument is also returned if
// can_eval() is false.
argument instruction::eval(bool check_eval) const
{
    if(op.name() == "@literal")
        return this->get_literal().get_argument();
    if(not is_context_free(op))
        return {};
    if(check_eval and not this->can_eval())
        return {};

    std::vector<argument> input_values;
    input_values.reserve(this->inputs().size());
    for(const auto& input : this->inputs())
        input_values.push_back(input->eval(false));
    return op.compute(result, input_values);
}

Paul's avatar
Paul committed
216
217
// Give the operator a chance to finalize itself against `ctx`, using this
// instruction's output shape and its inputs' shapes. This is a no-op for
// operators without a finalize hook.
void instruction::finalize(context& ctx)
{
    if(not has_finalize(this->op))
        return;
    this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}

Paul's avatar
Paul committed
222
// Resolve which instruction's storage the output of `ins` aliases.
// - If the operator reports no alias (negative index), `ins` itself is
//   returned.
// - Otherwise, with `shallow` the aliased input is returned directly.
// - Without `shallow` the chain is followed recursively. The recursive call
//   relies on the declaration's default for `shallow` (presumably false;
//   confirm in instruction.hpp).
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
    const auto alias_index = ins->get_operator().output_alias(to_shapes(ins->inputs()));
    if(alias_index < 0)
        return ins;
    const instruction_ref aliased = ins->inputs().at(alias_index);
    return shallow ? aliased : get_output_alias(aliased);
}

Paul's avatar
Paul committed
232
233
234
235
236
237
238
239
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
    std::vector<shape> shapes(args.size());
    std::transform(
        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
    return shapes;
}

Paul's avatar
Paul committed
240
241
// Ask `op` for the output shape it would produce given the shapes of `args`.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    auto input_shapes = to_shapes(args);
    return op.compute_shape(input_shapes);
}

Paul's avatar
Paul committed
245
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
246
} // namespace migraphx