"profiler/vscode:/vscode.git/clone" did not exist on "80e05267417f948e4f7e63c0fe807106d9a0c0ef"
program.cpp 3.71 KB
Newer Older
Paul's avatar
Paul committed
1
#include <rtg/program.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/instruction.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <utility>

namespace rtg {

Paul's avatar
Paul committed
9
10
11
12
13
14
15
16
17
18
struct program_impl
{
    // Instructions refer to their arguments by raw pointer (see add_instruction),
    // so elements must never be relocated. std::list keeps each node at a fixed
    // address for its whole lifetime, which std::vector would not.
    std::list<instruction> instructions;
};

program::program() 
: impl(std::make_unique<program_impl>())
{}

Paul's avatar
Paul committed
19
20
21
// Destructor and move operations are defaulted here, where program_impl is a
// complete type (it is defined above). This is presumably required because `impl`
// is a smart pointer to program_impl. NOTE(review): confirm that the header
// declares them without inline definitions.
program::~program() = default;
program::program(program&&) noexcept = default;
program& program::operator=(program&&) = default;
Paul's avatar
Paul committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

/// Append an instruction that applies @p op to @p args.
///
/// The result shape is computed here, once, by `op.compute_shape` from the
/// shapes of the arguments, and stored on the new instruction.
///
/// @param op    operation to apply
/// @param args  input instructions; each must already belong to this program
///              (checked by assert in debug builds)
/// @return pointer to the new instruction; it stays valid for the program's
///         lifetime because instructions are kept in a std::list
instruction* program::add_instruction(operation op, std::vector<instruction*> args)
{
    assert(std::all_of(
               args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) &&
           "Argument is not an existing instruction");
    std::vector<shape> shapes(args.size());
    std::transform(
        args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
    shape r = op.compute_shape(shapes);
    // Neither op nor r is used again, so hand them over instead of copying.
    // args is still needed by the sanity check below, so it is copied.
    impl->instructions.push_back({std::move(op), std::move(r), args});
    assert(impl->instructions.back().arguments == args);
    return std::addressof(impl->instructions.back());
}

instruction* program::add_literal(literal l)
{
    // Build the literal instruction in place from the moved-in value. The list
    // keeps the new element at a stable address, so returning its address is safe.
    auto& storage = impl->instructions;
    storage.emplace_back(std::move(l));
    instruction& added = storage.back();
    return std::addressof(added);
}

/// Append an input-parameter instruction called @p name with shape @p s.
///
/// The instruction has no arguments. Its value is supplied at evaluation time;
/// eval() presumably looks it up by @p name in its params map.
///
/// @return pointer to the new instruction (stable; instructions live in a std::list)
instruction* program::add_parameter(std::string name, shape s)
{
    // Both parameters are by-value sinks: move them in rather than copying the
    // shape a second time.
    impl->instructions.push_back({builtin::param{std::move(name)}, std::move(s), {}});
    return std::addressof(impl->instructions.back());
}

bool program::has_instruction(const instruction* ins) const
{
    return std::find_if(impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
               return ins == std::addressof(x);
           }) != impl->instructions.end();
}

Paul's avatar
Paul committed
56
/// Evaluate the program by executing every instruction in insertion order.
///
/// Literal instructions yield their stored value, and parameter instructions yield
/// the matching entry of @p params. Every other instruction applies its operation
/// to the previously computed values of its arguments.
///
/// @param params  values for the program's parameters, keyed by parameter name
/// @return the value produced by the last instruction, wrapped in a literal
/// @throws std::out_of_range if a parameter instruction has no entry in @p params
literal program::eval(std::unordered_map<std::string, argument> params) const
{
    // Parameter ops are named "@param" plus one separator character plus the
    // parameter name, so the user-visible name starts at this offset.
    // NOTE(review): assumes builtin::param::name() has exactly that form — confirm.
    constexpr std::size_t param_name_offset = 7;

    std::unordered_map<const instruction*, argument> results;
    argument result;
    for(auto& ins : impl->instructions)
    {
        // Query the op name once per instruction instead of once per branch test.
        const auto& op_name = ins.op.name();
        if(op_name == "@literal")
        {
            result = ins.lit.get_argument();
        }
        else if(starts_with(op_name, "@param"))
        {
            const std::string param_name = op_name.substr(param_name_offset);
            auto it = params.find(param_name);
            // Same exception type unordered_map::at would throw, but naming the
            // missing parameter instead of an opaque library message.
            if(it == params.end())
                throw std::out_of_range("Missing program parameter: " + param_name);
            result = it->second;
        }
        else
        {
            // add_instruction only accepts arguments already in the program, and
            // instructions are appended, so every argument's value is already in
            // `results` by the time we reach its user.
            std::vector<argument> values(ins.arguments.size());
            std::transform(ins.arguments.begin(),
                           ins.arguments.end(),
                           values.begin(),
                           [&](instruction* i) { return results.at(i); });
            result = ins.op.compute(values);
        }
        results.emplace(std::addressof(ins), result);
    }
    // NOTE(review): for an empty program `result` is still default-constructed
    // here — confirm that building a literal from it is well-defined.
    return literal{result.get_shape(), result.data()};
}

Paul's avatar
Paul committed
84
85
86
87
88
void program::print() const
{
    std::unordered_map<const instruction*, std::string> names;
    int count = 0;

Paul's avatar
Paul committed
89
    for(auto& ins : impl->instructions)
Paul's avatar
Paul committed
90
91
92
93
94
95
96
97
98
99
100
101
102
    {
        std::string var_name = "@" + std::to_string(count);
        if(starts_with(ins.op.name(), "@param"))
        {
            var_name = ins.op.name().substr(7);
        }

        std::cout << var_name << " = ";

        std::cout << ins.op.name();

        if(ins.op.name() == "@literal")
        {
Paul's avatar
Paul committed
103
            if(ins.lit.get_shape().elements() > 10)
Paul's avatar
Paul committed
104
105
106
                std::cout << "{ ... }";
            else
                std::cout << "{" << ins.lit << "}";
Paul's avatar
Paul committed
107
108
109
110
111
        }

        if(!ins.arguments.empty())
        {
            char delim = '(';
Paul's avatar
Paul committed
112
            for(auto&& arg : ins.arguments)
Paul's avatar
Paul committed
113
            {
Paul's avatar
Paul committed
114
                assert(this->has_instruction(arg) && "Instruction not found");
Paul's avatar
Paul committed
115
116
117
118
119
120
121
122
123
124
125
                std::cout << delim << names.at(arg);
                delim = ',';
            }
            std::cout << ")";
        }

        std::cout << " -> " << ins.result;

        std::cout << std::endl;

        names.emplace(std::addressof(ins), var_name);
Paul's avatar
Paul committed
126
        count++;
Paul's avatar
Paul committed
127
128
129
    }
}

Paul's avatar
Paul committed
130
} // namespace rtg