instruction.hpp 3.49 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP

#include <migraph/literal.hpp>
#include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/erase.hpp>

#include <algorithm>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>
Paul's avatar
Paul committed
9
#include <string>
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
namespace migraph {
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
14
// Forward declaration: instruction::replace(shape) below calls compute_shape,
// whose inline definition appears later in this header after struct instruction.
shape compute_shape(operation op, std::vector<instruction_ref> args);

Paul's avatar
Paul committed
15
16
struct instruction
{
Paul's avatar
Paul committed
17
18
    instruction() {}

Paul's avatar
Paul committed
19
    instruction(operation o, shape r, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
20
        : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
Paul's avatar
Paul committed
21
22
    {
    }
Paul's avatar
Paul committed
23

Paul's avatar
Paul committed
24
    instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
Paul's avatar
Paul committed
25

Paul's avatar
Paul committed
26
27
28
29
30
31
32
33
34
35
36
37
    void replace(operation o, shape r, std::vector<instruction_ref> args)
    {
        op = o;
        replace(std::move(r));
        replace(std::move(args));
    }

    void replace(shape r)
    {
        if(r != result)
        {
            result = r;
Paul's avatar
Paul committed
38
            for(auto&& ins : output)
Paul's avatar
Paul committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
            {
                ins->replace(compute_shape(ins->op, ins->arguments));
            }
        }
    }

    void replace(std::vector<instruction_ref> args)
    {
        clear_arguments();
        arguments = std::move(args);
    }

    void clear_arguments()
    {
Paul's avatar
Paul committed
53
        for(auto&& arg : arguments)
Paul's avatar
Paul committed
54
        {
Paul's avatar
Paul committed
55
            migraph::erase(arg->output, *this);
Paul's avatar
Paul committed
56
        }
Paul's avatar
Paul committed
57
        arguments.clear();
Paul's avatar
Paul committed
58
59
60
61
62
63
64
    }

    friend bool operator==(const instruction& i, instruction_ref ref)
    {
        return std::addressof(i) == std::addressof(*ref);
    }

Paul's avatar
Paul committed
65
    bool valid(instruction_ref start) const
Paul's avatar
Paul committed
66
    {
Paul's avatar
Paul committed
67
        std::vector<shape> shapes(arguments.size());
Paul's avatar
Paul committed
68
69
70
        std::transform(arguments.begin(), arguments.end(), shapes.begin(), [](instruction_ref ins) {
            return ins->result;
        });
Paul's avatar
Paul committed
71
72
73
74
        shape computed;
        try
        {
            computed = op.compute_shape(shapes);
Paul's avatar
Paul committed
75
76
77
78
79
        }
        catch(migraph::exception&)
        {
            return false;
        }
Paul's avatar
Paul committed
80
        return result == computed &&
Paul's avatar
Paul committed
81
               std::all_of(output.begin(),
Paul's avatar
Paul committed
82
83
84
85
86
87
                           output.end(),
                           [&](instruction_ref i) {
                               return std::find(i->arguments.begin(), i->arguments.end(), *this) !=
                                      i->arguments.end();
                           }) &&
               std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
Paul's avatar
Paul committed
88
                   auto self = std::find(i->output.begin(), i->output.end(), *this);
Paul's avatar
Paul committed
89
90
                   return self != i->output.end() &&
                          std::distance(start, i) < std::distance(start, *self);
Paul's avatar
Paul committed
91
               });
Paul's avatar
Paul committed
92
93
    }

Paul's avatar
Paul committed
94
    friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
Paul's avatar
Paul committed
95

Paul's avatar
Paul committed
96
    friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
Paul's avatar
Paul committed
97

Paul's avatar
Paul committed
98
    friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
Paul's avatar
Paul committed
99

Paul's avatar
Paul committed
100
    operation op;
Paul's avatar
Paul committed
101
    shape result;
Paul's avatar
Paul committed
102
    std::vector<instruction_ref> output;
Paul's avatar
Paul committed
103
    std::vector<instruction_ref> arguments;
Paul's avatar
Paul committed
104
    literal lit;
Paul's avatar
Paul committed
105
106
};

Paul's avatar
Paul committed
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/// Registers `ref` as a user of each of its arguments by appending `ref` to
/// every argument's `output` list. An argument that occurs several times in
/// `ref->arguments` receives one entry per occurrence.
inline void backreference(instruction_ref ref)
{
    auto& inputs = ref->arguments;
    for(auto it = inputs.begin(); it != inputs.end(); ++it)
    {
        (*it)->output.push_back(ref);
    }
}

/// Computes the shape `op` produces when applied to `args`: gathers each
/// argument's `result` shape, in order, and forwards them to
/// `op.compute_shape`. Exceptions from `op.compute_shape` propagate.
// TODO: Move to a cpp file
// TODO: Use const ref for vector (the forward declaration must change with it)
inline shape compute_shape(operation op, std::vector<instruction_ref> args)
{
    std::vector<shape> input_shapes;
    input_shapes.reserve(args.size());
    for(const auto& arg : args)
    {
        input_shapes.push_back(arg->result);
    }
    return op.compute_shape(input_shapes);
}

Paul's avatar
Paul committed
123
} // namespace migraph
Paul's avatar
Paul committed
124
125

#endif