instruction.cpp 12.5 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
4
5
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
6

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
9

10
11
12
13
14
15
/// Build a unary predicate that reports whether its argument equals `x`
/// (using std::equal_to<T>).
///
/// `x` is captured BY VALUE so the returned predicate stays valid even if it
/// outlives this call — e.g. when equal_to() is invoked with a temporary or the
/// predicate is stored. Capturing the `const T&` parameter by reference would
/// dangle as soon as such a temporary is destroyed.
template <class T>
auto equal_to(const T& x)
{
    return [x](const T& y) { return std::equal_to<T>{}(x, y); };
}

Paul's avatar
Paul committed
16
17
18
19
// Construct an instruction from an operator, its result shape and its
// instruction arguments, with no module arguments. Implemented by delegating
// to the module-aware constructor below with an empty module list.
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
    : instruction(std::move(o), std::move(r), std::move(args), std::vector<module_ref>{})
{
}

// Construct an instruction from an operator, its result shape, its
// instruction arguments and the modules it references.
instruction::instruction(operation o,
                         shape r,
                         std::vector<instruction_ref> args,
                         std::vector<module_ref> modules)
    : op(std::move(o)),
      result(std::move(r)),
      arguments(std::move(args)),
      module_args(std::move(modules))
{
}

// Construct a "@literal" instruction that owns `l`; its result shape is the
// literal's own shape.
// NOTE(review): `result` reads `l` before `lit` moves from it, which relies on
// `result` being declared before `lit` in the class (header not visible here).
instruction::instruction(literal l)
    : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
}
Paul's avatar
Paul committed
36

37
void instruction::replace(const shape& r, bool stop)
Paul's avatar
Paul committed
38
39
{
    if(r != result)
Paul's avatar
Paul committed
40
    {
Paul's avatar
Paul committed
41
        result = r;
Shucai Xiao's avatar
Shucai Xiao committed
42
43
        if(stop and not r.standard())
            return;
44

Paul's avatar
Paul committed
45
        for(auto&& ins : output)
Paul's avatar
Paul committed
46
        {
47
48
49
            if(ins->name() == "@return")
                continue;

Paul's avatar
Paul committed
50
            assert(ins->name().front() != '@');
51
            ins->recompute_shape(stop);
Paul's avatar
Paul committed
52
53
        }
    }
Paul's avatar
Paul committed
54
}
Paul's avatar
Paul committed
55

Paul's avatar
Paul committed
56
// Install a new operator and recompute the result shape (which in turn
// propagates to consumers). The normalization flag described the previous
// operator, so it is cleared.
void instruction::replace(operation o)
{
    op         = std::move(o);
    normalized = false;
    recompute_shape();
}

Shucai Xiao's avatar
Shucai Xiao committed
63
void instruction::recompute_shape(bool non_std_stop)
64
{
Shucai Xiao's avatar
Shucai Xiao committed
65
    replace(compute_shape(op, arguments, module_args), non_std_stop);
66
}
Paul's avatar
Paul committed
67

Paul's avatar
Paul committed
68
69
70
void instruction::clear_arguments()
{
    for(auto&& arg : arguments)
Paul's avatar
Paul committed
71
    {
Paul's avatar
Paul committed
72
        arg->remove_output(*this);
Paul's avatar
Paul committed
73
    }
Paul's avatar
Paul committed
74
    arguments.clear();
Shucai Xiao's avatar
Shucai Xiao committed
75
    module_args.clear();
Paul's avatar
Paul committed
76
77
78
79
80
81
}

// Identity comparison: true when `ref` refers to the very object `i`
// (not a structural comparison).
bool operator==(const instruction& i, instruction_ref ref)
{
    const instruction* referenced = std::addressof(*ref);
    return referenced == std::addressof(i);
}
Paul's avatar
Paul committed
82

Shucai Xiao's avatar
Shucai Xiao committed
83
// Validity including back-links: valid() must hold, and every argument must
// list this instruction among its outputs. With `check_order`, each argument
// must also sit strictly before this instruction, comparing distances from
// `start`.
bool instruction::valid(instruction_ref start, bool check_order) const
{
    if(not valid())
        return false;

    return std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref arg) {
        const auto& consumers = arg->outputs();
        auto self_it          = std::find(consumers.begin(), consumers.end(), *this);
        if(self_it == consumers.end())
            return false;
        if(not check_order)
            return true;
        return std::distance(start, arg) < std::distance(start, *self_it);
    });
}

bool instruction::valid() const
{
    shape computed;
    if(op.name() == "@literal")
Paul's avatar
Paul committed
100
    {
Paul's avatar
Paul committed
101
        computed = lit.get_shape();
Paul's avatar
Paul committed
102
    }
Paul's avatar
Paul committed
103
    else if(op.name() == "@param")
Paul's avatar
Paul committed
104
    {
Paul's avatar
Paul committed
105
        computed = result;
Paul's avatar
Paul committed
106
    }
107
108
109
110
    else if(op.name() == "@return")
    {
        computed = {};
    }
Paul's avatar
Paul committed
111
    else
Paul's avatar
Paul committed
112
    {
Paul's avatar
Paul committed
113
        try
Paul's avatar
Paul committed
114
        {
Shucai Xiao's avatar
Shucai Xiao committed
115
            computed = compute_shape(op, arguments, module_args);
Paul's avatar
Paul committed
116
        }
Paul's avatar
Paul committed
117
        catch(migraphx::exception&)
Paul's avatar
Paul committed
118
        {
Paul's avatar
Paul committed
119
            return false;
Paul's avatar
Paul committed
120
121
        }
    }
122

Shucai Xiao's avatar
Shucai Xiao committed
123
124
    return (result == computed) &&
           std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
Paul's avatar
Paul committed
125
126
127
               return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
           });
}
Paul's avatar
Paul committed
128

Paul's avatar
Paul committed
129
130
131
132
133
134
shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const
{
    assert(op.name() == "@literal");
    return lit;
}
Paul's avatar
Paul committed
135

Paul's avatar
Paul committed
136
const operation& instruction::get_operator() const { return op; }
Paul's avatar
Paul committed
137

Paul's avatar
Paul committed
138
std::string instruction::name() const { return op.name(); }
Paul's avatar
Paul committed
139

Paul's avatar
Paul committed
140
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
Paul's avatar
Paul committed
141

Shucai Xiao's avatar
Shucai Xiao committed
142
143
const std::vector<module_ref>& instruction::module_inputs() const { return module_args; }

Paul's avatar
Paul committed
144
const std::vector<instruction_ref>& instruction::outputs() const { return output; }
Paul's avatar
Paul committed
145

Paul's avatar
Paul committed
146
147
// Structural equality: identical argument references (same order), equal
// result shape, operator and module arguments, and, for "@literal"
// instructions, equal literal values.
bool operator==(const instruction& x, const instruction& y)
{
    // Arguments compare by reference identity (iterator equality).
    const bool same_arguments = std::equal(x.arguments.begin(),
                                           x.arguments.end(),
                                           y.arguments.begin(),
                                           y.arguments.end(),
                                           std::equal_to<instruction_ref>{});
    if(not same_arguments)
        return false;

    const bool same_signature =
        x.result == y.result and x.op == y.op and x.module_args == y.module_args;
    if(not same_signature)
        return false;

    // Literals additionally carry a value that must match.
    return x.name() != "@literal" or x.lit == y.lit;
}
Paul's avatar
Paul committed
160

Paul's avatar
Paul committed
161
162
bool operator!=(const instruction& x, const instruction& y) { return !(x == y); }

Paul's avatar
Paul committed
163
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
Paul's avatar
Paul committed
164

Paul's avatar
Paul committed
165
bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
Paul's avatar
Paul committed
166

Paul's avatar
Paul committed
167
bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
Paul's avatar
Paul committed
168

Paul's avatar
Paul committed
169
170
// Record `ins` as a consumer of this instruction. Duplicates are ignored, so
// the output list keeps insertion order without repeats.
void instruction::add_output(instruction_ref ins)
{
    const bool already_present = std::any_of(output.begin(), output.end(), equal_to(ins));
    if(not already_present)
        output.push_back(ins);
}
Paul's avatar
Paul committed
174

Paul's avatar
Paul committed
175
176
177
178
179
// Ensure every input of `ref` lists `ref` among its outputs. This is safe to
// repeat because add_output() skips duplicates.
void instruction::backreference(instruction_ref ref)
{
    const auto& ref_inputs = ref->inputs();
    std::for_each(ref_inputs.begin(), ref_inputs.end(), [&](instruction_ref input) {
        input->add_output(ref);
    });
}
Paul's avatar
Paul committed
180

Paul's avatar
Paul committed
181
182
// Replace every use of `old` among `ins`'s arguments with `new_ins`, restore
// the output back-references, then recompute `ins`'s shape (`stop` is
// forwarded to recompute_shape).
void instruction::replace_argument(instruction_ref ins,
                                   instruction_ref old,
                                   instruction_ref new_ins,
                                   bool stop)
{
    instruction& target = *ins;
    target.replace_argument(old, new_ins);
    backreference(ins);
    target.recompute_shape(stop);
}
Paul's avatar
Paul committed
190

Shucai Xiao's avatar
Shucai Xiao committed
191
192
193
194
195
196
197
// Replace every use of module `old` among `ins`'s module arguments with
// `new_mod`, refresh back-references and recompute `ins`'s shape.
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
{
    instruction& target = *ins;
    target.replace_mod_argument(old, new_mod);
    backreference(ins);
    target.recompute_shape();
}

Paul's avatar
Paul committed
198
199
200
201
202
203
204
205
// Give `ins` a new operator, result shape and argument list, then make the new
// arguments list `ins` among their outputs.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args)
{
    instruction& target = *ins;
    target.replace(std::move(o), r, std::move(args));
    backreference(ins);
}
Paul's avatar
Paul committed
206

Shucai Xiao's avatar
Shucai Xiao committed
207
208
209
210
211
212
213
214
215
216
// Give `ins` a new operator, result shape, argument list and module-argument
// list, then make the new arguments list `ins` among their outputs.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args,
                          std::vector<module_ref> module_args)
{
    instruction& target = *ins;
    target.replace(std::move(o), r, std::move(args), std::move(module_args));
    backreference(ins);
}

Paul's avatar
Paul committed
217
218
// Install a new operator, result shape and argument list on this instruction.
// replace(args) goes through clear_arguments(), so any module arguments are
// dropped as well. Output back-references are not added here.
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    // The normalization flag described the previous operator.
    normalized = false;
    op         = std::move(o);

    replace(r);
    replace(std::move(args));
}
Paul's avatar
Paul committed
224

Shucai Xiao's avatar
Shucai Xiao committed
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
// Install a new operator, result shape, argument list and module-argument list
// on this instruction. Output back-references are not added here.
//
// `normalized` is reset, as in the other operator-replacing overloads. The
// flag refers to the previous operator; leaving it set would make
// need_normalization()/normalized_operator() skip normalizing the new one.
void instruction::replace(operation o,
                          const shape& r,
                          std::vector<instruction_ref> args,
                          std::vector<module_ref> mdl_args)
{
    normalized = false;
    op         = std::move(o);
    replace(r);
    replace(std::move(args), std::move(mdl_args));
}

// Rewire `ins` according to the given maps. Each argument found as a key in
// `map_insts` is replaced by its mapped instruction, and each module argument
// found in `map_mods` by its mapped module. Every replacement goes through the
// static replace_argument / replace_mod_argument helpers, so back-references
// and the shape are refreshed after each one.
void instruction::replace_refs(
    instruction_ref ins,
    const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
    const std::unordered_map<module_ref, module_ref>& map_mods)
{
    // Index loops read the *current* element on every step. A replacement
    // rewrites all occurrences of the old value in place (sizes never change),
    // so later steps observe the rewritten values.
    for(std::size_t idx = 0; idx < ins->inputs().size(); ++idx)
    {
        const instruction_ref current = ins->inputs()[idx];
        auto hit                      = map_insts.find(current);
        if(hit != map_insts.end())
            instruction::replace_argument(ins, current, hit->second);
    }

    for(std::size_t idx = 0; idx < ins->module_inputs().size(); ++idx)
    {
        const module_ref current = ins->module_inputs()[idx];
        auto hit                 = map_mods.find(current);
        if(hit != map_mods.end())
            instruction::replace_mod_argument(ins, current, hit->second);
    }
}

Paul's avatar
Paul committed
262
263
264
265
266
// Replace the argument list wholesale. clear_arguments() first unlinks this
// instruction from its old inputs and also empties the module arguments.
// Back-references for the new arguments are not added here.
void instruction::replace(std::vector<instruction_ref> args)
{
    clear_arguments();
    arguments.swap(args);
}
Paul's avatar
Paul committed
267

Shucai Xiao's avatar
Shucai Xiao committed
268
269
270
271
272
273
274
// Replace both the instruction arguments and the module arguments after
// unlinking this instruction from its old inputs. Back-references for the new
// arguments are not added here.
void instruction::replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args)
{
    clear_arguments();
    arguments.swap(args);
    module_args.swap(mdl_args);
}

Paul's avatar
Paul committed
275
276
// Substitute `new_ins` for every occurrence of `old` in the argument list
// (which must contain `old`; asserted in debug builds). `old` then forgets
// this instruction as an output.
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
    const auto is_old = equal_to(old);
    assert(std::any_of(arguments.begin(), arguments.end(), is_old));
    std::replace_if(arguments.begin(), arguments.end(), is_old, new_ins);
    old->remove_output(*this);
}
Paul's avatar
Paul committed
281

Shucai Xiao's avatar
Shucai Xiao committed
282
283
284
285
286
287
// Substitute `new_mod` for every occurrence of `old` in the module-argument
// list (which must contain `old`; asserted in debug builds).
void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
{
    assert(std::find(module_args.begin(), module_args.end(), old) != module_args.end());
    std::replace(module_args.begin(), module_args.end(), old, new_mod);
}

Paul's avatar
Paul committed
288
289
290
291
292
293
bool instruction::can_eval() const
{
    if(op.name() == "@literal")
    {
        return true;
    }
Paul's avatar
Paul committed
294
    else if(is_context_free(op))
Paul's avatar
Paul committed
295
    {
Paul's avatar
Paul committed
296
297
        return std::all_of(
            this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
Paul's avatar
Paul committed
298
299
300
301
302
303
304
    }
    else
    {
        return false;
    }
}

Paul's avatar
Paul committed
305
// Evaluate this instruction without a context. Literals return their stored
// argument. Context-free operators recursively evaluate their inputs
// (unchecked) and run the normalized operator. Anything else, or a failed
// can_eval() check when `check_eval` is set, yields an empty argument.
argument instruction::eval(bool check_eval) const
{
    if(op.name() == "@literal")
        return this->get_literal().get_argument();
    if(not is_context_free(op))
        return {};
    if(check_eval and not this->can_eval())
        return {};

    std::vector<argument> input_values;
    for(const auto& input : this->inputs())
        input_values.push_back(input->eval(false));
    return normalized_operator().compute(result, input_values);
}

Paul's avatar
Paul committed
325
326
// Let the operator run its finalize hook, when it has one, with the context,
// this instruction's result shape and its input shapes.
void instruction::finalize(context& ctx)
{
    if(not has_finalize(this->op))
        return;
    this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}

331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
void instruction::print(std::ostream& os,
                        instruction_ref ins,
                        const std::unordered_map<instruction_ref, std::string>& names)
{
    os << names.at(ins) << " = ";

    os << ins->get_operator();

    if(ins->name() == "@literal")
    {
        if(ins->get_literal().get_shape().elements() > 10)
            os << "{ ... }";
        else
            os << "{" << ins->get_literal() << "}";
    }

    if(!ins->inputs().empty())
    {
        char delim = '(';
        for(auto&& arg : ins->inputs())
        {
Shucai Xiao's avatar
Shucai Xiao committed
352
353
            std::string arg_name = contains(names, arg) ? names.at(arg) : "?";
            os << delim << arg_name;
354
355
356
357
358
            delim = ',';
        }
        os << ")";
    }

Shucai Xiao's avatar
Shucai Xiao committed
359
360
361
362
363
364
365
366
367
368
369
370
    // print module inputs
    if(!ins->module_inputs().empty())
    {
        std::string delim = ", [";
        for(auto&& mod_arg : ins->module_inputs())
        {
            os << delim << mod_arg->name();
            delim = ", ";
        }
        os << "]";
    }

371
372
373
374
375
    // skip return instruction shape
    if(ins->name() != "@return")
        os << " -> " << ins->get_shape();
}

Paul Fultz II's avatar
Paul Fultz II committed
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
// Short label for `ins` used by debug_print(). Non-literals print their
// operator. Literals print "@literal" followed by their value, elided as
// "{ ... }" above 10 elements.
static void debug_name(std::ostream& os, const instruction& ins)
{
    if(ins.name() != "@literal")
    {
        os << ins.get_operator();
        return;
    }

    os << "@literal";
    const auto& value = ins.get_literal();
    if(value.get_shape().elements() > 10)
        os << "{ ... }";
    else
        os << "{" << value << "}";
}

// Print this instruction to std::cout as
//   <label>(<input label>, <input label>, ...) -> <shape>
// (parentheses omitted when there are no inputs), ending with std::endl.
void instruction::debug_print() const
{
    std::ostream& out = std::cout;
    debug_name(out, *this);

    const auto& ins_inputs = this->inputs();
    if(not ins_inputs.empty())
    {
        out << "(";
        bool first = true;
        for(const auto& input : ins_inputs)
        {
            if(not first)
                out << ", ";
            first = false;
            debug_name(out, *input);
        }
        out << ")";
    }
    out << " -> " << this->get_shape() << std::endl;
}

Paul's avatar
Paul committed
407
// Resolve which input the output of `ins` aliases, according to its
// operator's output_alias() index. A negative index means "no alias" and
// returns `ins`. With `shallow`, only one hop is taken; otherwise the chain is
// followed recursively (using the declaration's default for `shallow`).
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
    const auto alias_index = ins->get_operator().output_alias(to_shapes(ins->inputs()));
    if(alias_index < 0)
        return ins;

    auto aliased = ins->inputs().at(alias_index);
    return shallow ? aliased : get_output_alias(aliased);
}

417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
// Mark whether this instruction's operator attributes have been normalized.
void instruction::set_normalized(bool value) { this->normalized = value; }

bool instruction::is_normalized() const { return this->normalized; }

// True when the operator wants attribute normalization and it has not yet been
// applied to this instruction.
bool instruction::need_normalization() const
{
    const bool op_wants_it = this->get_operator().need_normalization();
    return op_wants_it and not this->normalized;
}

// Return a copy of the operator with its attributes normalized against the
// lens of the first input, when normalization is still pending.
// The operator is returned unchanged when:
//  * no normalization is needed,
//  * there is no input to take lens from (previously `front()` on an empty
//    vector, which is undefined behavior), or
//  * normalize_attributes() reports failure.
operation instruction::normalized_operator() const
{
    operation o = this->get_operator();
    if(not this->need_normalization())
        return o;

    if(this->inputs().empty())
        return o;

    auto lens = this->inputs().front()->get_shape().lens();
    if(not normalize_attributes(o, lens))
        return this->get_operator();
    return o;
}

Paul's avatar
Paul committed
438
439
440
441
442
443
444
445
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
    std::vector<shape> shapes(args.size());
    std::transform(
        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
    return shapes;
}

Paul's avatar
Paul committed
446
447
// Ask `op` for its output shape given the shapes of `args`.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    const auto input_shapes = to_shapes(args);
    return op.compute_shape(input_shapes);
}

Shucai Xiao's avatar
Shucai Xiao committed
451
452
453
454
455
456
457
458
459
460
461
462
463
// Ask `op` for its output shape given the shapes of `args`. The
// module-aware compute_shape overload is used only when `mods` is non-empty.
shape compute_shape(const operation& op,
                    const std::vector<instruction_ref>& args,
                    const std::vector<module_ref>& mods)
{
    if(not mods.empty())
        return op.compute_shape(to_shapes(args), mods);
    return op.compute_shape(to_shapes(args));
}
464
465
466
467
468
469
470
471
472
473
474
475
476
477

std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs)
{
    shape new_shape;
    try
    {
        new_shape = op.compute_shape(inputs);
    }
    catch(...)
    {
        return {};
    }
    return {new_shape};
}
478
479
480
481
482
483

// Raw (non-owning) pointer to the instruction that `ins` refers to.
migraphx::instruction* as_address(const instruction_ref& ins) noexcept
{
    instruction& target = *ins;
    return std::addressof(target);
}

Paul's avatar
Paul committed
484
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
485
} // namespace migraphx