instruction.cpp 6.99 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
Paul's avatar
Paul committed
4

Paul's avatar
Paul committed
5
namespace migraphx {
Paul's avatar
Paul committed
6
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
9
10
11
// Builds an instruction from an operation, its result shape and its input
// instructions. Each by-value parameter is moved into the matching member.
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
    : op(std::move(o)),
      result(std::move(r)),
      arguments(std::move(args))
{
}
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
14
15
16
// Builds a literal instruction: the operation is builtin::literal, the result
// shape is taken from the literal, and the literal itself is moved into `lit`.
// NOTE(review): reading l.get_shape() before std::move(l) relies on `result`
// being declared before `lit` in the class -- confirm against the header.
instruction::instruction(literal l)
    : op(builtin::literal{}),
      result(l.get_shape()),
      lit(std::move(l))
{
}
Paul's avatar
Paul committed
17

Paul's avatar
Paul committed
18
19
20
void instruction::replace(const shape& r)
{
    if(r != result)
Paul's avatar
Paul committed
21
    {
Paul's avatar
Paul committed
22
23
        result = r;
        for(auto&& ins : output)
Paul's avatar
Paul committed
24
        {
25
26
27
            if(ins->name() == "@return")
                continue;

Paul's avatar
Paul committed
28
29
            assert(ins->name().front() != '@');
            ins->recompute_shape();
Paul's avatar
Paul committed
30
31
        }
    }
Paul's avatar
Paul committed
32
}
Paul's avatar
Paul committed
33

Paul's avatar
Paul committed
34
// Replaces this instruction's operation with `o` (moved in) and then
// recomputes the result shape from the new operation and current arguments.
void instruction::replace(operation o)
{
    op = std::move(o);
    recompute_shape();
}

Paul's avatar
Paul committed
40
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
Paul's avatar
Paul committed
41

Paul's avatar
Paul committed
42
43
44
void instruction::clear_arguments()
{
    for(auto&& arg : arguments)
Paul's avatar
Paul committed
45
    {
Paul's avatar
Paul committed
46
        arg->remove_output(*this);
Paul's avatar
Paul committed
47
    }
Paul's avatar
Paul committed
48
49
50
51
52
53
54
    arguments.clear();
}

// Identity comparison: true when `ref` refers to the very same instruction
// object as `i` (addresses are compared, not contents).
bool operator==(const instruction& i, instruction_ref ref)
{
    const auto* lhs = std::addressof(i);
    const auto* rhs = std::addressof(*ref);
    return lhs == rhs;
}
Paul's avatar
Paul committed
55

Paul's avatar
Paul committed
56
57
58
59
60
61
62
63
64
65
66
67
68
// Checks this instruction and its position relative to its inputs.
//
// Returns true only when valid() holds and, for every argument, this
// instruction appears in that argument's outputs and the argument is closer to
// `start` than this instruction is (both measured with std::distance).
bool instruction::valid(instruction_ref start) const
{
    if(not valid())
        return false;

    for(instruction_ref arg : arguments)
    {
        const auto& users = arg->outputs();
        auto self         = std::find(users.begin(), users.end(), *this);
        // The argument must list this instruction among its outputs.
        if(self == users.end())
            return false;
        // The argument must come before this instruction when counted from start.
        if(not(std::distance(start, arg) < std::distance(start, *self)))
            return false;
    }
    return true;
}

bool instruction::valid() const
{
    shape computed;
    if(op.name() == "@literal")
Paul's avatar
Paul committed
69
    {
Paul's avatar
Paul committed
70
        computed = lit.get_shape();
Paul's avatar
Paul committed
71
    }
Paul's avatar
Paul committed
72
    else if(op.name() == "@param")
Paul's avatar
Paul committed
73
    {
Paul's avatar
Paul committed
74
        computed = result;
Paul's avatar
Paul committed
75
    }
76
77
78
79
    else if(op.name() == "@return")
    {
        computed = {};
    }
Paul's avatar
Paul committed
80
    else
Paul's avatar
Paul committed
81
    {
Paul's avatar
Paul committed
82
        try
Paul's avatar
Paul committed
83
        {
Paul's avatar
Paul committed
84
            computed = compute_shape(op, arguments);
Paul's avatar
Paul committed
85
        }
Paul's avatar
Paul committed
86
        catch(migraphx::exception&)
Paul's avatar
Paul committed
87
        {
Paul's avatar
Paul committed
88
            return false;
Paul's avatar
Paul committed
89
90
        }
    }
91

Paul's avatar
Paul committed
92
93
94
95
    return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
               return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
           });
}
Paul's avatar
Paul committed
96

Paul's avatar
Paul committed
97
98
99
100
101
102
// Returns a copy of the stored result shape.
shape instruction::get_shape() const
{
    return result;
}
// Returns a reference to the stored literal. Asserts that the operation is
// named "@literal".
const literal& instruction::get_literal() const
{
    assert(op.name() == "@literal");
    return this->lit;
}
Paul's avatar
Paul committed
103

Paul's avatar
Paul committed
104
// Returns a reference to the operation held by this instruction.
const operation& instruction::get_operator() const
{
    return op;
}
Paul's avatar
Paul committed
105

Paul's avatar
Paul committed
106
// Returns the name reported by this instruction's operation.
std::string instruction::name() const
{
    return op.name();
}
Paul's avatar
Paul committed
107

Paul's avatar
Paul committed
108
// Returns the list of argument (input) instructions.
const std::vector<instruction_ref>& instruction::inputs() const
{
    return arguments;
}
Paul's avatar
Paul committed
109

Paul's avatar
Paul committed
110
// Returns the list of instructions recorded as outputs of this instruction.
const std::vector<instruction_ref>& instruction::outputs() const
{
    return output;
}
Paul's avatar
Paul committed
111

Paul's avatar
Paul committed
112
113
// Structural equality of two instructions.
//
// The result shape, operation and argument list must all compare equal. When
// `x` is a "@literal" instruction the stored literals must also match. The
// output lists are not part of the comparison.
bool operator==(const instruction& x, const instruction& y)
{
    if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
        return false;
    if(x.name() == "@literal")
        return x.lit == y.lit;
    return true;
}
Paul's avatar
Paul committed
120

Paul's avatar
Paul committed
121
122
// Negation of the structural operator== for two instructions.
bool operator!=(const instruction& x, const instruction& y)
{
    return not(x == y);
}

Paul's avatar
Paul committed
123
// Same comparison as operator==(instruction, instruction_ref), with the
// operands swapped.
bool operator==(instruction_ref ref, const instruction& i)
{
    return i == ref;
}
Paul's avatar
Paul committed
124

Paul's avatar
Paul committed
125
// Negation of operator==(instruction, instruction_ref).
bool operator!=(const instruction& i, instruction_ref ref)
{
    return not(i == ref);
}
Paul's avatar
Paul committed
126

Paul's avatar
Paul committed
127
// Negation of operator==(instruction, instruction_ref), with the operands
// given in the opposite order.
bool operator!=(instruction_ref ref, const instruction& i)
{
    return not(i == ref);
}
Paul's avatar
Paul committed
128

Paul's avatar
Paul committed
129
130
131
132
133
// Records `ins` as an output of this instruction unless it is already listed,
// so the output list never holds the same reference twice via this function.
void instruction::add_output(instruction_ref ins)
{
    const bool already_listed = std::find(output.begin(), output.end(), ins) != output.end();
    if(already_listed)
        return;
    output.push_back(ins);
}
Paul's avatar
Paul committed
134

Paul's avatar
Paul committed
135
136
137
138
139
// Registers `ref` as an output on each of ref's input instructions.
void instruction::backreference(instruction_ref ref)
{
    const auto& producers = ref->inputs();
    std::for_each(producers.begin(), producers.end(), [&](instruction_ref producer) {
        producer->add_output(ref);
    });
}
Paul's avatar
Paul committed
140

Paul's avatar
Paul committed
141
142
143
144
145
146
147
148
// Replaces the argument `old` of `ins` with `new_ins`, then re-registers `ins`
// as an output of all its inputs and recomputes its shape.
void instruction::replace_argument(instruction_ref ins,
                                   instruction_ref old,
                                   instruction_ref new_ins)
{
    instruction& target = *ins;
    // Swap the argument and drop `ins` from old's outputs.
    target.replace_argument(old, new_ins);
    // Make the inputs (including new_ins) list `ins` as an output.
    backreference(ins);
    // The inputs changed, so the result shape may have changed too.
    target.recompute_shape();
}
Paul's avatar
Paul committed
149

Paul's avatar
Paul committed
150
151
152
153
154
155
156
157
// Replaces the operation, result shape and arguments of `ins` in one step,
// then registers `ins` as an output of each of its new inputs.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args)
{
    instruction& target = *ins;
    target.replace(std::move(o), r, std::move(args));
    backreference(ins);
}
Paul's avatar
Paul committed
158

Paul's avatar
Paul committed
159
160
161
162
163
164
// Replaces this instruction's operation, result shape and arguments.
// The operation is stored first; the shape is then installed through
// replace(const shape&), and finally the argument list is swapped out through
// replace(std::vector<instruction_ref>).
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    op = std::move(o);
    replace(r);
    replace(std::move(args));
}
Paul's avatar
Paul committed
165

Paul's avatar
Paul committed
166
167
168
169
170
// Replaces the argument list with `args`. The current arguments are detached
// first via clear_arguments(). This function does not add this instruction to
// the outputs of the new arguments.
void instruction::replace(std::vector<instruction_ref> args)
{
    clear_arguments();
    arguments = std::move(args);
}
Paul's avatar
Paul committed
171

Paul's avatar
Paul committed
172
173
// Replaces every occurrence of `old` in the argument list with `new_ins` and
// removes this instruction from old's outputs. Asserts that `old` is currently
// one of the arguments.
//
// This overload does not add this instruction to new_ins's outputs; the static
// replace_argument(ins, old, new_ins) does that through backreference().
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
    assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
    std::replace(arguments.begin(), arguments.end(), old, new_ins);
    old->remove_output(*this);
}
Paul's avatar
Paul committed
178

Paul's avatar
Paul committed
179
180
181
182
183
184
bool instruction::can_eval() const
{
    if(op.name() == "@literal")
    {
        return true;
    }
Paul's avatar
Paul committed
185
    else if(is_context_free(op))
Paul's avatar
Paul committed
186
    {
Paul's avatar
Paul committed
187
188
        return std::all_of(
            this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
Paul's avatar
Paul committed
189
190
191
192
193
194
195
    }
    else
    {
        return false;
    }
}

Paul's avatar
Paul committed
196
// Evaluates this instruction and returns the resulting argument.
//
// - "@literal": returns the argument held by the stored literal.
// - context-free operation: evaluates every input and passes the results to
//   op.compute() together with the stored result shape. When `check_eval` is
//   true, can_eval() is consulted first and a default-constructed argument is
//   returned if it fails.
// - anything else: returns a default-constructed argument.
//
// Inputs are evaluated with check_eval == false, so the can_eval() walk is
// performed at most once, by the outermost call.
argument instruction::eval(bool check_eval) const
{
    if(op.name() == "@literal")
    {
        return this->get_literal().get_argument();
    }
    if(is_context_free(op))
    {
        if(check_eval and not this->can_eval())
            return {};
        std::vector<argument> args;
        // The final size is known up front; avoid reallocations while filling.
        args.reserve(this->inputs().size());
        std::transform(this->inputs().begin(),
                       this->inputs().end(),
                       std::back_inserter(args),
                       [](auto arg) { return arg->eval(false); });
        return op.compute(result, args);
    }
    return {};
}

Paul's avatar
Paul committed
216
217
// Calls the operation's finalize() with the context, this instruction's shape
// and the shapes of its inputs. Does nothing when has_finalize(op) is false.
void instruction::finalize(context& ctx)
{
    if(has_finalize(this->op))
        this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}

Paul Fultz II's avatar
Paul Fultz II committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
// Writes a short description of `ins` to `os`.
// Non-literal instructions are printed via their operation. Literals are
// printed as "@literal{...}", with the contents elided to "{ ... }" when the
// literal's shape has more than 10 elements.
static void debug_name(std::ostream& os, const instruction& ins)
{
    if(ins.name() != "@literal")
    {
        os << ins.get_operator();
        return;
    }

    os << "@literal";
    const auto& value = ins.get_literal();
    if(value.get_shape().elements() > 10)
        os << "{ ... }";
    else
        os << "{" << value << "}";
}

void instruction::debug_print() const
{
    debug_name(std::cout, *this);
    std::string delim = "(";
    for(auto arg : this->inputs())
    {
        std::cout << delim;
        debug_name(std::cout, *arg);
        delim = ", ";
    }
    if(not this->inputs().empty())
        std::cout << ")";
    std::cout << " -> " << this->get_shape() << std::endl;
}

Paul's avatar
Paul committed
253
// Resolves which instruction the output of `ins` aliases.
//
// The operation's output_alias() is asked for an input index, given the input
// shapes. A negative index yields `ins` itself. Otherwise the input at that
// index is returned directly when `shallow` is true, or followed recursively
// when it is false. The recursive call relies on the default value of
// `shallow` from the declaration, which is not visible in this file.
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
    auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
    if(i < 0)
        return ins;
    if(shallow)
        return ins->inputs().at(i);
    return get_output_alias(ins->inputs().at(i));
}

Paul's avatar
Paul committed
263
264
265
266
267
268
269
270
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
    std::vector<shape> shapes(args.size());
    std::transform(
        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
    return shapes;
}

Paul's avatar
Paul committed
271
272
// Computes the output shape of `op` when applied to the shapes of `args`.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    return op.compute_shape(to_shapes(args));
}

Paul's avatar
Paul committed
276
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
277
} // namespace migraphx