instruction.cpp 12.8 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
4
5
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
6

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
9

10
11
12
13
14
15
// Returns a predicate that reports whether its argument equals `x` (using std::equal_to<T>).
// `x` is captured by value so the predicate stays valid even when it outlives the object it
// was built from (e.g. built from a temporary and then stored); a by-reference capture would
// dangle in that case.
template <class T>
auto equal_to(const T& x)
{
    return [x](const T& y) { return std::equal_to<T>{}(x, y); };
}

Paul's avatar
Paul committed
16
17
18
19
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
    : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
Paul's avatar
Paul committed
20

Shucai Xiao's avatar
Shucai Xiao committed
21
22
23
24
25
26
27
28
29
30
31
// Builds an instruction from its operator, result shape, input instructions and module
// arguments. All four are taken by value and moved into the members.
// This constructor does not register the instruction as an output of `args`; see
// backreference() for that.
instruction::instruction(operation o,
                         shape r,
                         std::vector<instruction_ref> args,
                         std::vector<module_ref> modules)
    : op(std::move(o)),
      result(std::move(r)),
      arguments(std::move(args)),
      module_args(std::move(modules))
{
}

Paul's avatar
Paul committed
32
33
34
35
// Builds a literal instruction: the operator is builtin::literal and the result shape is the
// literal's own shape.
// NOTE(review): `result(l.get_shape())` only reads `l` before `lit(std::move(l))` moves from it
// if `result` is declared before `lit` in the class. Members are initialized in declaration
// order, not initializer-list order, so keep that declaration order in the header.
instruction::instruction(literal l)
    : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
}
Paul's avatar
Paul committed
36

Paul's avatar
Paul committed
37
38
39
void instruction::replace(const shape& r)
{
    if(r != result)
Paul's avatar
Paul committed
40
    {
Paul's avatar
Paul committed
41
42
        result = r;
        for(auto&& ins : output)
Paul's avatar
Paul committed
43
        {
44
45
46
            if(ins->name() == "@return")
                continue;

Paul's avatar
Paul committed
47
48
            assert(ins->name().front() != '@');
            ins->recompute_shape();
Paul's avatar
Paul committed
49
50
        }
    }
Paul's avatar
Paul committed
51
}
Paul's avatar
Paul committed
52

Paul's avatar
Paul committed
53
// Swaps in a new operator, marks it as not yet normalized and recomputes the result shape.
void instruction::replace(operation o)
{
    op         = std::move(o);
    normalized = false;
    recompute_shape();
}

Shucai Xiao's avatar
Shucai Xiao committed
60
void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); }
Paul's avatar
Paul committed
61

Paul's avatar
Paul committed
62
63
64
void instruction::clear_arguments()
{
    for(auto&& arg : arguments)
Paul's avatar
Paul committed
65
    {
Paul's avatar
Paul committed
66
        arg->remove_output(*this);
Paul's avatar
Paul committed
67
    }
Paul's avatar
Paul committed
68
    arguments.clear();
Shucai Xiao's avatar
Shucai Xiao committed
69
    module_args.clear();
Paul's avatar
Paul committed
70
71
72
73
74
75
}

// Identity comparison: true only when `ref` refers to the very object `i`.
bool operator==(const instruction& i, instruction_ref ref)
{
    const instruction* lhs = std::addressof(i);
    const instruction* rhs = std::addressof(*ref);
    return lhs == rhs;
}
Paul's avatar
Paul committed
76

charlie's avatar
charlie committed
77
78
// Forward declaration of the file-local helper that writes a short description of `ins`
// (its operator, or its literal payload) to `os`; it is defined further down in this file.
static void debug_name(std::ostream& os, const instruction& ins);

Shucai Xiao's avatar
Shucai Xiao committed
79
// Validates this instruction (see valid()) and checks that it is registered as an output of
// every one of its arguments. When `check_order` is set, it also requires each argument to
// appear before this instruction when walking forward from `start`.
bool instruction::valid(instruction_ref start, bool check_order) const
{
    // Counts the increments needed to walk from s to e. std::distance is not used because the
    // iterators may come from different ranges, which is undefined behavior.
    // Positions are matched by object identity (address). instruction::operator!= compares
    // instructions structurally (arguments, shape, operator, module args, literal data), so it
    // would stop at an earlier instruction that merely looks identical and report a wrong
    // position. It is also far more expensive per step.
    // NOTE(review): assumes e is reachable from s by incrementing; otherwise this walks past
    // the end of the range.
    auto ins_dist = [](instruction_ref s, instruction_ref e) {
        int dist = 0;
        while(std::addressof(*s) != std::addressof(*e))
        {
            ++s;
            ++dist;
        }
        return dist;
    };

    return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
               // Locate this instruction in the argument's output list (identity match via
               // operator==(instruction_ref, const instruction&)).
               auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
               bool ret  = self != i->outputs().end();
               if(check_order)
               {
                   // Each argument must come before this instruction. `and` short-circuits,
                   // so `*self` is only dereferenced when it is not the end iterator.
                   ret = ret and (ins_dist(start, i) < ins_dist(start, *self));
               }
               return ret;
           });
}

bool instruction::valid() const
{
    shape computed;
    if(op.name() == "@literal")
Paul's avatar
Paul committed
109
    {
Paul's avatar
Paul committed
110
        computed = lit.get_shape();
Paul's avatar
Paul committed
111
    }
Paul's avatar
Paul committed
112
    else if(op.name() == "@param")
Paul's avatar
Paul committed
113
    {
Paul's avatar
Paul committed
114
        computed = result;
Paul's avatar
Paul committed
115
    }
116
117
118
119
    else if(op.name() == "@return")
    {
        computed = {};
    }
Paul's avatar
Paul committed
120
    else
Paul's avatar
Paul committed
121
    {
Paul's avatar
Paul committed
122
        try
Paul's avatar
Paul committed
123
        {
Shucai Xiao's avatar
Shucai Xiao committed
124
            computed = compute_shape(op, arguments, module_args);
Paul's avatar
Paul committed
125
        }
Paul's avatar
Paul committed
126
        catch(migraphx::exception&)
Paul's avatar
Paul committed
127
        {
Paul's avatar
Paul committed
128
            return false;
Paul's avatar
Paul committed
129
130
        }
    }
131

Shucai Xiao's avatar
Shucai Xiao committed
132
133
    return (result == computed) &&
           std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
Paul's avatar
Paul committed
134
135
136
               return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
           });
}
Paul's avatar
Paul committed
137

Paul's avatar
Paul committed
138
139
140
141
142
143
shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const
{
    assert(op.name() == "@literal");
    return lit;
}
Paul's avatar
Paul committed
144

Paul's avatar
Paul committed
145
const operation& instruction::get_operator() const { return op; }
Paul's avatar
Paul committed
146

Paul's avatar
Paul committed
147
std::string instruction::name() const { return op.name(); }
Paul's avatar
Paul committed
148

Paul's avatar
Paul committed
149
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
Paul's avatar
Paul committed
150

Shucai Xiao's avatar
Shucai Xiao committed
151
152
const std::vector<module_ref>& instruction::module_inputs() const { return module_args; }

Paul's avatar
Paul committed
153
const std::vector<instruction_ref>& instruction::outputs() const { return output; }
Paul's avatar
Paul committed
154

Paul's avatar
Paul committed
155
156
// Structural equality: same input references, result shape, operator and module arguments,
// plus equal literal payloads for literal instructions.
bool operator==(const instruction& x, const instruction& y)
{
    const bool same_inputs = std::equal(x.arguments.begin(),
                                        x.arguments.end(),
                                        y.arguments.begin(),
                                        y.arguments.end(),
                                        std::equal_to<instruction_ref>{});
    if(not same_inputs)
        return false;

    if(not(std::tie(x.result, x.op, x.module_args) == std::tie(y.result, y.op, y.module_args)))
        return false;

    return x.name() != "@literal" or x.lit == y.lit;
}
Paul's avatar
Paul committed
169

Paul's avatar
Paul committed
170
171
bool operator!=(const instruction& x, const instruction& y) { return !(x == y); }

Paul's avatar
Paul committed
172
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
Paul's avatar
Paul committed
173

Paul's avatar
Paul committed
174
bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
Paul's avatar
Paul committed
175

Paul's avatar
Paul committed
176
bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
Paul's avatar
Paul committed
177

Paul's avatar
Paul committed
178
179
// Records `ins` as a consumer of this instruction, keeping each consumer at most once.
void instruction::add_output(instruction_ref ins)
{
    const bool already_present = std::find(output.begin(), output.end(), ins) != output.end();
    if(not already_present)
        output.push_back(ins);
}
Paul's avatar
Paul committed
183

Paul's avatar
Paul committed
184
185
186
187
188
// Registers `ref` as an output of each of its current inputs.
void instruction::backreference(instruction_ref ref)
{
    const auto& ref_inputs = ref->inputs();
    std::for_each(
        ref_inputs.begin(), ref_inputs.end(), [&](instruction_ref arg) { arg->add_output(ref); });
}
Paul's avatar
Paul committed
189

Paul's avatar
Paul committed
190
191
192
193
194
195
196
197
// Replaces every occurrence of `old` in ins's argument list with `new_ins`. The member
// overload also removes ins from old's output list. Then ins is registered as an output of
// its (now updated) inputs, and its shape is recomputed.
// The order matters: arguments must be updated before backreference() and recompute_shape().
void instruction::replace_argument(instruction_ref ins,
                                   instruction_ref old,
                                   instruction_ref new_ins)
{
    ins->replace_argument(old, new_ins);
    backreference(ins);
    ins->recompute_shape();
}
Paul's avatar
Paul committed
198

Shucai Xiao's avatar
Shucai Xiao committed
199
200
201
202
203
204
205
// Replaces module argument `old` with `new_mod` in ins, re-registers ins as an output of its
// inputs, and recomputes ins's shape. recompute_shape() passes the module arguments to
// compute_shape, so the new module can change the result shape.
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
{
    ins->replace_mod_argument(old, new_mod);
    backreference(ins);
    ins->recompute_shape();
}

Paul's avatar
Paul committed
206
207
208
209
210
211
212
213
// Replaces ins's operator, shape and arguments in place, then registers ins as an output of
// each new argument. The member replace() unlinks ins from its previous arguments through
// clear_arguments(), which also clears any module arguments; this overload does not restore
// them.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args)
{
    ins->replace(std::move(o), r, std::move(args));
    backreference(ins);
}
Paul's avatar
Paul committed
214

Shucai Xiao's avatar
Shucai Xiao committed
215
216
217
218
219
220
221
222
223
224
// Replaces ins's operator, shape, arguments and module arguments in place, then registers ins
// as an output of each new argument. The previous arguments are unlinked by the member
// replace() (via clear_arguments()).
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args,
                          std::vector<module_ref> module_args)
{
    ins->replace(std::move(o), r, std::move(args), std::move(module_args));
    backreference(ins);
}

Paul's avatar
Paul committed
225
226
// Installs a new (not yet normalized) operator, stores shape `r` (propagating it to consumers)
// and swaps in the new arguments. Existing module arguments are dropped by replace(args).
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    op         = std::move(o);
    normalized = false;
    replace(r);
    replace(std::move(args));
}
Paul's avatar
Paul committed
232

Shucai Xiao's avatar
Shucai Xiao committed
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
// Installs a new operator, stores shape `r` (propagating it to consumers) and swaps in the new
// instruction and module arguments.
void instruction::replace(operation o,
                          const shape& r,
                          std::vector<instruction_ref> args,
                          std::vector<module_ref> mdl_args)
{
    // The incoming operator has not been normalized yet. Reset the flag, as the other
    // replace(operation, ...) overloads do. Otherwise need_normalization() and
    // normalized_operator() would keep the previous operator's state and skip normalizing
    // this one.
    normalized = false;
    op         = std::move(o);
    replace(r);
    replace(std::move(args), std::move(mdl_args));
}

// Remaps ins's instruction arguments through `map_insts` and its module arguments through
// `map_mods`. Entries missing from a map are left untouched.
void instruction::replace_refs(
    instruction_ref ins,
    const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
    const std::unordered_map<module_ref, module_ref>& map_mods)
{
    // Each replacement rewrites elements of ins's argument vector in place, so the current
    // element is re-read at every step.
    const auto& args = ins->inputs();
    for(std::size_t k = 0; k < args.size(); ++k)
    {
        const instruction_ref arg = args[k];
        const auto found          = map_insts.find(arg);
        if(found != map_insts.end())
            instruction::replace_argument(ins, arg, found->second);
    }

    const auto& mods = ins->module_inputs();
    for(std::size_t k = 0; k < mods.size(); ++k)
    {
        const module_ref mod = mods[k];
        const auto found     = map_mods.find(mod);
        if(found != map_mods.end())
            instruction::replace_mod_argument(ins, mod, found->second);
    }
}

Paul's avatar
Paul committed
270
271
272
273
274
// Unlinks this instruction from its old inputs and takes `args` as the new argument list.
// clear_arguments() also drops the module arguments.
void instruction::replace(std::vector<instruction_ref> args)
{
    clear_arguments();
    arguments.swap(args);
}
Paul's avatar
Paul committed
275

Shucai Xiao's avatar
Shucai Xiao committed
276
277
278
279
280
281
282
// Unlinks this instruction from its old inputs and takes `args` / `mdl_args` as the new
// instruction and module argument lists.
void instruction::replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args)
{
    clear_arguments();
    arguments.swap(args);
    module_args.swap(mdl_args);
}

Paul's avatar
Paul committed
283
284
// Replaces every occurrence of `old` in the argument list with `new_ins` and removes this
// instruction from old's output list. `old` must be an argument (asserted in debug builds).
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
    assert(std::find(arguments.begin(), arguments.end(), old) != arguments.end());
    std::replace(arguments.begin(), arguments.end(), old, new_ins);
    old->remove_output(*this);
}
Paul's avatar
Paul committed
289

Shucai Xiao's avatar
Shucai Xiao committed
290
291
292
293
294
295
// Replaces every occurrence of module `old` with `new_mod`. `old` must be a module argument
// (asserted in debug builds).
void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
{
    assert(std::find(module_args.begin(), module_args.end(), old) != module_args.end());
    std::replace(module_args.begin(), module_args.end(), old, new_mod);
}

Paul's avatar
Paul committed
296
297
298
299
300
301
bool instruction::can_eval() const
{
    if(op.name() == "@literal")
    {
        return true;
    }
Paul's avatar
Paul committed
302
    else if(is_context_free(op))
Paul's avatar
Paul committed
303
    {
Paul's avatar
Paul committed
304
305
        return std::all_of(
            this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
Paul's avatar
Paul committed
306
307
308
309
310
311
312
    }
    else
    {
        return false;
    }
}

Paul's avatar
Paul committed
313
// Evaluates this instruction. A literal yields its stored argument. A context-free operator
// is computed from its recursively evaluated inputs using the normalized operator. Anything
// else, or a failed can_eval() check when `check_eval` is set, yields an empty argument.
argument instruction::eval(bool check_eval) const
{
    if(op.name() == "@literal")
        return this->get_literal().get_argument();
    if(not is_context_free(op))
        return {};
    if(check_eval and not this->can_eval())
        return {};

    std::vector<argument> args;
    args.reserve(this->inputs().size());
    for(auto&& arg : this->inputs())
        args.push_back(arg->eval(false));
    return normalized_operator().compute(result, args);
}

Paul's avatar
Paul committed
333
334
// Calls the operator's finalize hook, if it has one, with `ctx`, the result shape and the
// input shapes.
void instruction::finalize(context& ctx)
{
    if(not has_finalize(this->op))
        return;
    this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}

339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
void instruction::print(std::ostream& os,
                        instruction_ref ins,
                        const std::unordered_map<instruction_ref, std::string>& names)
{
    os << names.at(ins) << " = ";

    os << ins->get_operator();

    if(ins->name() == "@literal")
    {
        if(ins->get_literal().get_shape().elements() > 10)
            os << "{ ... }";
        else
            os << "{" << ins->get_literal() << "}";
    }

    if(!ins->inputs().empty())
    {
        char delim = '(';
        for(auto&& arg : ins->inputs())
        {
Shucai Xiao's avatar
Shucai Xiao committed
360
361
            std::string arg_name = contains(names, arg) ? names.at(arg) : "?";
            os << delim << arg_name;
362
363
364
365
366
            delim = ',';
        }
        os << ")";
    }

Shucai Xiao's avatar
Shucai Xiao committed
367
368
369
370
371
372
373
374
375
376
377
378
    // print module inputs
    if(!ins->module_inputs().empty())
    {
        std::string delim = ", [";
        for(auto&& mod_arg : ins->module_inputs())
        {
            os << delim << mod_arg->name();
            delim = ", ";
        }
        os << "]";
    }

379
380
381
382
383
    // skip return instruction shape
    if(ins->name() != "@return")
        os << " -> " << ins->get_shape();
}

Paul Fultz II's avatar
Paul Fultz II committed
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
// Writes a short description of `ins`: its operator, or "@literal" followed by the payload
// (abbreviated as "{ ... }" above 10 elements).
static void debug_name(std::ostream& os, const instruction& ins)
{
    if(ins.name() != "@literal")
    {
        os << ins.get_operator();
        return;
    }
    os << "@literal";
    const auto& value = ins.get_literal();
    if(value.get_shape().elements() > 10)
        os << "{ ... }";
    else
        os << "{" << value << "}";
}

// Prints this instruction and its direct inputs to std::cout as
// "<self>(<in0>, <in1>, ...) -> <shape>", ending with std::endl.
void instruction::debug_print() const
{
    debug_name(std::cout, *this);
    const auto& ins_args = this->inputs();
    for(std::size_t k = 0; k < ins_args.size(); ++k)
    {
        std::cout << (k == 0 ? "(" : ", ");
        debug_name(std::cout, *ins_args[k]);
    }
    if(not ins_args.empty())
        std::cout << ")";
    std::cout << " -> " << this->get_shape() << std::endl;
}

Paul's avatar
Paul committed
415
// Returns the input whose storage `ins`'s output aliases, as reported by the operator's
// output_alias(), or `ins` itself when there is no alias (negative index). With `shallow`
// only one level is followed; otherwise the alias chain is followed recursively.
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
    const auto alias_index = ins->get_operator().output_alias(to_shapes(ins->inputs()));
    if(alias_index < 0)
        return ins;
    auto aliased = ins->inputs().at(alias_index);
    return shallow ? aliased : get_output_alias(aliased);
}

425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
// Overrides the "operator already normalized" flag.
void instruction::set_normalized(bool value) { this->normalized = value; }

// Whether the operator has been marked as normalized.
bool instruction::is_normalized() const { return this->normalized; }

// True when the operator requests normalization and has not been normalized yet.
bool instruction::need_normalization() const
{
    const bool op_requests_it = this->get_operator().need_normalization();
    return op_requests_it and not this->normalized;
}

// Returns a copy of the operator with its attributes normalized against the first input's
// dimensions when normalization is needed. Falls back to the unmodified operator when
// normalization does not apply or normalize_attributes() reports failure.
operation instruction::normalized_operator() const
{
    operation o = this->get_operator();
    // Normalization reads the first input's lens. With no inputs there is nothing to
    // normalize against, so skip it instead of calling front() on an empty vector (UB).
    if(this->need_normalization() and not this->inputs().empty())
    {
        auto lens = this->inputs().front()->get_shape().lens();
        // On failure, return the original operator rather than the working copy `o`.
        if(!normalize_attributes(o, lens))
            return this->get_operator();
    }
    return o;
}

Paul's avatar
Paul committed
446
447
448
449
450
451
452
453
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
    std::vector<shape> shapes(args.size());
    std::transform(
        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
    return shapes;
}

Paul's avatar
Paul committed
454
455
// Computes the output shape of `op` applied to the shapes of `args`.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    const auto input_shapes = to_shapes(args);
    return op.compute_shape(input_shapes);
}

Shucai Xiao's avatar
Shucai Xiao committed
459
460
461
462
463
464
465
466
467
468
469
470
471
// Computes the output shape of `op` from the shapes of `args`, also passing the module
// arguments when there are any.
shape compute_shape(const operation& op,
                    const std::vector<instruction_ref>& args,
                    const std::vector<module_ref>& mods)
{
    const auto input_shapes = to_shapes(args);
    return mods.empty() ? op.compute_shape(input_shapes) : op.compute_shape(input_shapes, mods);
}
472
473
474
475
476
477
478
479
480
481
482
483
484
485

std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs)
{
    shape new_shape;
    try
    {
        new_shape = op.compute_shape(inputs);
    }
    catch(...)
    {
        return {};
    }
    return {new_shape};
}
486
487
488
489
490
491

// Raw address of the instruction that `ins` refers to.
migraphx::instruction* as_address(const instruction_ref& ins) noexcept
{
    instruction& target = *ins;
    return std::addressof(target);
}

Paul's avatar
Paul committed
492
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
493
} // namespace migraphx