program.cpp 9 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraph/program.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp>
Paul's avatar
Paul committed
4
#include <migraph/env.hpp>
Paul's avatar
Paul committed
5
#include <iostream>
Paul's avatar
Paul committed
6
#include <sstream>
Paul's avatar
Paul committed
7
8
#include <algorithm>

Paul's avatar
Paul committed
9
namespace migraph {
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
12
// Environment flag checked by program::compile: when enabled, the program is
// printed before compilation and after every pass, along with each pass name.
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE)

Paul's avatar
Paul committed
13
14
15
16
struct program_impl
{
    // A list is used to keep references to an instruction stable
    std::list<instruction> instructions;
Paul's avatar
Paul committed
17
    context ctx;
Paul's avatar
Paul committed
18
19
};

Paul's avatar
Paul committed
20
// Return the operation stored in the instruction `ins` refers to. The result
// is a reference into that instruction, so it is valid as long as the
// instruction itself is.
const operation& get_operation(instruction_ref ins)
{
    const instruction& node = *ins;
    return node.op;
}
Paul's avatar
Paul committed
21

Paul's avatar
Paul committed
22
program::program() : impl(std::make_unique<program_impl>()) {}
Paul's avatar
Paul committed
23

Paul's avatar
Paul committed
24
program::program(program&&) noexcept = default;
Paul's avatar
Paul committed
25
26
program& program::operator=(program&&) noexcept = default;
program::~program() noexcept                    = default;
Paul's avatar
Paul committed
27

Paul's avatar
Paul committed
28
// Append an instruction computing `op` over `args` at the end of the program.
// Equivalent to insert_instruction() with the end position as insertion point.
instruction_ref program::add_instruction(operation op, std::vector<instruction_ref> args)
{
    const auto tail = impl->instructions.end();
    return insert_instruction(tail, std::move(op), std::move(args));
}
Paul's avatar
Paul committed
32
33
// Insert a new instruction computing `op` over `args` immediately before `ins`
// (which may be end()). The output shape is derived with compute_shape(), and
// backreference() is then called on the new instruction (presumably linking it
// into its arguments' output lists).
//
// Returns an iterator to the newly inserted instruction.
instruction_ref
program::insert_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args)
{
    // The insertion point must be end() or an instruction of this program.
    assert(ins == impl->instructions.end() or has_instruction(ins));
    assert(std::all_of(
               args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
           "Argument is not an exisiting instruction");
    // '@'-prefixed names ("@literal", "@param", "@outline") appear reserved for
    // builtins, which are created through the dedicated add_* functions.
    assert(not starts_with(op.name(), "@"));
    shape r = compute_shape(op, args);
    // `op` and `r` are not used again, so move them into the new instruction.
    // `args` is copied because the assertion below still compares against it.
    auto result = impl->instructions.insert(ins, {std::move(op), std::move(r), args});
    backreference(result);
    assert(result->arguments == args);
    assert(result->valid(begin()));
    return result;
}

Paul's avatar
Paul committed
48
49
50
51
52
53
// Rewrite `ins` in place so that it computes `op` over `args`. The new output
// shape comes from compute_shape(), and backreference() is re-run on `ins`.
// Iterators to `ins` stay valid because the list node is reused.
//
// Returns `ins`.
instruction_ref
program::replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args)
{
    // Like the other replace_instruction overload, the instruction being
    // rewritten must belong to this program.
    assert(has_instruction(ins));
    assert(std::all_of(
               args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
           "Argument is not an exisiting instruction");
    assert(not starts_with(op.name(), "@"));

    shape r = compute_shape(op, args);
    ins->replace(op, r, args);
    backreference(ins);
    assert(ins->valid(begin()));
    return ins;
}

Paul's avatar
Paul committed
63
// Redirect every consumer of `ins` to use `rep` instead. `ins` itself is left
// in the program (now possibly unused); `rep` is returned.
instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref rep)
{
    assert(has_instruction(ins));
    assert(has_instruction(rep));
    assert(ins != rep);
    // TODO: Should it be an error if the output is empty?
    if(ins->output.empty())
    {
        return rep;
    }
    // Iterate over a snapshot of the consumers: replace_argument() rewires
    // `out` away from `ins`, which may remove `out` from ins->output while the
    // loop is still walking that container.
    auto outputs = ins->output;
    for(auto out : outputs)
    {
        // TODO: Check for possible cycles
        if(out != rep)
        {
            replace_argument(out, ins, rep);
        }
        assert(out->valid(begin()));
    }
    // Replacement should not be dead code unless it's the last instruction
    assert(!rep->output.empty() or rep == std::prev(end()));
    assert(ins->valid(begin()));
    assert(rep->valid(begin()));
    return rep;
}

Paul's avatar
Paul committed
89
// Erase a single instruction that nothing consumes any more and return the
// iterator that follows it.
instruction_ref program::remove_instruction(instruction_ref ins)
{
    assert(has_instruction(ins));
    // Every consumer must already have been redirected elsewhere.
    assert(ins->output.empty());
    // clear_arguments() runs before the node is destroyed, presumably so the
    // arguments stop listing `ins` among their outputs.
    ins->clear_arguments();
    auto& instructions = impl->instructions;
    return instructions.erase(ins);
}

97
98
// Erase the half-open range [first, last) and return `last`. The argument
// links of every instruction in the range are cleared first; afterwards no
// instruction in the range may still have consumers (asserted).
instruction_ref program::remove_instructions(instruction_ref first, instruction_ref last)
{
    if(first == last)
        return first;
    // Both ends must belong to this program and `last` must be reachable from
    // `first`. Together this guarantees that every instruction in
    // [first, last) belongs to this program.
    assert(has_instruction(first));
    assert(last == impl->instructions.end() or has_instruction(last));
    assert([&] {
        auto it = first;
        while(it != impl->instructions.end() and it != last)
            ++it;
        return it == last;
    }());
    std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
    assert(std::all_of(first, last, [&](instruction& ins) { return ins.output.empty(); }));
    return impl->instructions.erase(first, last);
}

// Relocate the instruction `src` so that it sits immediately before `dst`
// (which may be end()). The node is spliced rather than copied, so iterators
// and references to it remain valid. Returns `src`.
instruction_ref program::move_instruction(instruction_ref src, instruction_ref dst)
{
    // Same ownership checks as the other mutating members.
    assert(has_instruction(src));
    assert(dst == impl->instructions.end() or has_instruction(dst));
    impl->instructions.splice(dst, impl->instructions, src);
    return src;
}

Paul's avatar
Paul committed
114
// Create a literal (constant) instruction at the front of the program and
// return an iterator to it.
instruction_ref program::add_literal(literal l)
{
    auto& instructions = impl->instructions;
    instructions.emplace_front(std::move(l));
    return instructions.begin();
}

// Create an "@outline" instruction of shape `s` at the front of the program.
// It has no arguments; eval() produces an argument with this shape and a null
// data pointer for it. Returns an iterator to the new instruction.
instruction_ref program::add_outline(shape s)
{
    auto& instructions = impl->instructions;
    instructions.push_front({builtin::outline{s}, s, {}});
    return instructions.begin();
}

Paul's avatar
Paul committed
126
// Create a named input parameter of shape `s` at the front of the program.
// eval() binds it to the argument stored under `name` in its params map.
// Returns an iterator to the new instruction.
instruction_ref program::add_parameter(std::string name, shape s)
{
    auto& instructions = impl->instructions;
    instructions.push_front({builtin::param{std::move(name)}, s, {}});
    return instructions.begin();
}

Paul's avatar
Paul committed
132
// Shape of the parameter called `name`, or a default-constructed shape when
// the program has no such parameter.
shape program::get_parameter_shape(std::string name) const
{
    auto is_named_param = [&](const instruction& x) {
        return x.op.name() == "@param" and any_cast<builtin::param>(x.op).parameter == name;
    };
    auto found = std::find_if(impl->instructions.begin(), impl->instructions.end(), is_named_param);
    if(found == this->end())
        return {};
    return found->result;
}

Paul's avatar
Paul committed
151
152
153
std::unordered_map<std::string, shape> program::get_parameter_shapes() const
{
    std::unordered_map<std::string, shape> result;
Paul's avatar
Paul committed
154
    for(auto&& ins : impl->instructions)
Paul's avatar
Paul committed
155
156
157
    {
        if(ins.op.name() == "@param")
        {
Paul's avatar
Paul committed
158
            auto&& name  = any_cast<builtin::param>(ins.op).parameter;
Paul's avatar
Paul committed
159
160
161
162
163
164
            result[name] = ins.result;
        }
    }
    return result;
}

Paul's avatar
Paul committed
165
// Return true when `ins` refers to an instruction owned by this program.
// Membership is decided by comparing instruction addresses. end() does not
// refer to any instruction and must not be dereferenced, so it is rejected
// before the lookup (previously has_instruction(end()) was undefined behavior).
bool program::has_instruction(instruction_ref ins) const
{
    if(ins == impl->instructions.end())
        return false;
    const instruction* target = std::addressof(*ins);
    return std::any_of(impl->instructions.begin(),
                       impl->instructions.end(),
                       [&](const instruction& x) { return std::addressof(x) == target; });
}

Paul's avatar
Paul committed
173
174
// First instruction of the program, in program order.
instruction_ref program::begin() const
{
    auto& instructions = impl->instructions;
    return instructions.begin();
}
// One past the last instruction of the program.
instruction_ref program::end() const
{
    auto& instructions = impl->instructions;
    return instructions.end();
}
175

Paul's avatar
Paul committed
176
// Shape of the program's result, i.e. of its last instruction (eval() returns
// that instruction's value). An empty program has no result, so a
// default-constructed shape is returned instead of calling back() on an empty
// list, which is undefined behavior.
shape program::get_shape() const
{
    if(impl->instructions.empty())
        return {};
    return impl->instructions.back().result;
}
Paul's avatar
Paul committed
177

Paul's avatar
Paul committed
178
179
// Return the first instruction whose valid() check fails, or end() when every
// instruction is valid.
instruction_ref program::validate() const
{
    const auto first = impl->instructions.begin();
    auto is_invalid  = [&](const instruction& i) { return not i.valid(first); };
    return std::find_if(first, impl->instructions.end(), is_invalid);
}

Paul's avatar
Paul committed
185
186
void program::compile(const target& t)
{
Paul's avatar
Paul committed
187
    assert(this->validate() == impl->instructions.end());
Paul's avatar
Paul committed
188
    this->impl->ctx = t.get_context();
Paul's avatar
Paul committed
189
    if(enabled(MIGRAPH_TRACE_COMPILE{}))
Paul's avatar
Paul committed
190
191
        std::cout << *this << std::endl << std::endl;
    ;
Paul's avatar
Paul committed
192
    for(auto&& p : t.get_passes(this->impl->ctx))
Paul's avatar
Paul committed
193
    {
Paul's avatar
Paul committed
194
195
        if(enabled(MIGRAPH_TRACE_COMPILE{}))
            std::cout << "Pass: " << p.name() << std::endl;
Paul's avatar
Paul committed
196
        p.apply(*this);
Paul's avatar
Paul committed
197
198
        if(enabled(MIGRAPH_TRACE_COMPILE{}))
            std::cout << *this << std::endl << std::endl;
Paul's avatar
Paul committed
199
#ifndef NDEBUG
Paul's avatar
Paul committed
200
        auto invalid = this->validate();
Paul's avatar
Paul committed
201
202
        if(invalid != impl->instructions.end())
        {
Paul's avatar
Paul committed
203
            auto index = std::distance(impl->instructions.begin(), invalid);
Paul's avatar
Paul committed
204
            MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " +
Paul's avatar
Paul committed
205
                          std::to_string(index) + ": " + invalid->op.name());
Paul's avatar
Paul committed
206
        }
Paul's avatar
Paul committed
207
208
#endif
    }
Paul's avatar
Paul committed
209
    auto invalid = this->validate();
Paul's avatar
Paul committed
210
211
    if(invalid != impl->instructions.end())
    {
Paul's avatar
Paul committed
212
213
214
        auto index = std::distance(impl->instructions.begin(), invalid);
        MIGRAPH_THROW("Invalid program from compilation at instruction " + std::to_string(index));
    }
Paul's avatar
Paul committed
215
216
}

217
// Evaluate the program with the given parameter values (looked up by
// parameter name) and return the value of the last instruction. An empty
// program yields a default-constructed argument. Each instruction's value is
// recorded, keyed by the instruction's address, so that later instructions
// can fetch their inputs.
argument program::eval(std::unordered_map<std::string, argument> params) const
{
    assert(this->validate() == impl->instructions.end());
    std::unordered_map<const instruction*, argument> results;
    argument result;
    for(auto& ins : impl->instructions)
    {
        const std::string op_name = ins.op.name();
        if(op_name == "@literal")
        {
            result = ins.lit.get_argument();
        }
        else if(op_name == "@param")
        {
            result = params.at(any_cast<builtin::param>(ins.op).parameter);
        }
        else if(op_name == "@outline")
        {
            result = argument{ins.result, nullptr};
        }
        else
        {
            // Gather already-computed input values in argument order.
            std::vector<argument> inputs;
            inputs.reserve(ins.arguments.size());
            for(auto&& arg : ins.arguments)
                inputs.push_back(results.at(std::addressof(*arg)));
            result = ins.op.compute(this->impl->ctx, ins.result, inputs);
        }
        results.emplace(std::addressof(ins), result);
    }
    return result;
}

Paul's avatar
Paul committed
250
// Two programs are equal when their printed representations are identical.
bool operator==(const program& x, const program& y)
{
    const auto lhs = to_string(x);
    const auto rhs = to_string(y);
    return lhs == rhs;
}
Paul's avatar
Paul committed
251

Paul's avatar
Paul committed
252
std::ostream& operator<<(std::ostream& os, const program& p)
Paul's avatar
Paul committed
253
254
255
256
{
    std::unordered_map<const instruction*, std::string> names;
    int count = 0;

Paul's avatar
Paul committed
257
    for(auto& ins : p.impl->instructions)
Paul's avatar
Paul committed
258
259
    {
        std::string var_name = "@" + std::to_string(count);
Paul's avatar
Paul committed
260
        if(ins.op.name() == "@param")
Paul's avatar
Paul committed
261
        {
Paul's avatar
Paul committed
262
            var_name = any_cast<builtin::param>(ins.op).parameter;
Paul's avatar
Paul committed
263
264
        }

Paul's avatar
Paul committed
265
        os << var_name << " = ";
Paul's avatar
Paul committed
266

Paul's avatar
Paul committed
267
        os << ins.op;
Paul's avatar
Paul committed
268
269
270

        if(ins.op.name() == "@literal")
        {
Paul's avatar
Paul committed
271
            if(ins.lit.get_shape().elements() > 10)
Paul's avatar
Paul committed
272
                os << "{ ... }";
Paul's avatar
Paul committed
273
            else
Paul's avatar
Paul committed
274
                os << "{" << ins.lit << "}";
Paul's avatar
Paul committed
275
276
277
278
279
        }

        if(!ins.arguments.empty())
        {
            char delim = '(';
Paul's avatar
Paul committed
280
            for(auto&& arg : ins.arguments)
Paul's avatar
Paul committed
281
            {
Paul's avatar
Paul committed
282
283
                assert(p.has_instruction(arg) && "Instruction not found");
                os << delim << names.at(std::addressof(*arg));
Paul's avatar
Paul committed
284
285
                delim = ',';
            }
Paul's avatar
Paul committed
286
            os << ")";
Paul's avatar
Paul committed
287
288
        }

Paul's avatar
Paul committed
289
        os << " -> " << ins.result;
Paul's avatar
Paul committed
290

Paul's avatar
Paul committed
291
        os << std::endl;
Paul's avatar
Paul committed
292
293

        names.emplace(std::addressof(ins), var_name);
Paul's avatar
Paul committed
294
        count++;
Paul's avatar
Paul committed
295
    }
Paul's avatar
Paul committed
296
    return os;
Paul's avatar
Paul committed
297
298
}

Paul's avatar
Paul committed
299
} // namespace migraph