".github/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "2ae055629212693fd42b696606c18d01a2194465"
program.cpp 3.9 KB
Newer Older
Paul's avatar
Paul committed
1
#include <rtg/program.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/instruction.hpp>
#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace rtg {

Paul's avatar
Paul committed
9
10
11
12
13
14
// Private state of a `program`; `program` owns one instance through its `impl` member.
struct program_impl
{
    // A list is used to keep references to an instruction stable
    // (appending to a std::list never relocates the elements already stored).
    std::list<instruction> instructions;
};

Paul's avatar
Paul committed
15
// Constructs an empty program: allocates a new program_impl, whose instruction list starts empty.
program::program() : impl(std::make_unique<program_impl>()) {}
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
// Compiler-generated move constructor, defaulted out of line in this file,
// after program_impl has been fully defined.
program::program(program&&) noexcept = default;
Paul's avatar
Paul committed
18
19
// Compiler-generated move assignment and destructor, defaulted out of line in
// this file, after program_impl has been fully defined.
program& program::operator=(program&&) noexcept = default;
program::~program() noexcept                    = default;
Paul's avatar
Paul committed
20

Paul's avatar
Paul committed
21
// Appends an instruction that applies `op` to the results of `args`.
//
// The result shape of the new instruction is `op.compute_shape` applied to the
// `result` shapes of the arguments. The new instruction is also appended to the
// `output` list of every argument. Returns a reference to the new instruction.
//
// Every element of `args` must already be an instruction of this program; this
// is checked with assert only.
instruction_ref program::add_instruction(operation op, std::vector<instruction_ref> args)
{
    assert(std::all_of(
               args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
           "Argument is not an existing instruction");
    std::vector<shape> shapes(args.size());
    std::transform(
        args.begin(), args.end(), shapes.begin(), [](instruction_ref ins) { return ins->result; });
    shape r = op.compute_shape(shapes);
    // `op` and `r` are not used again, so hand them over instead of copying.
    // `args` is still read below and therefore has to be copied.
    impl->instructions.push_back({std::move(op), std::move(r), args});
    assert(impl->instructions.back().arguments == args);
    auto result = std::prev(impl->instructions.end());
    // Let each argument know that the new instruction consumes it.
    for(auto&& arg : args)
        arg->output.push_back(result);
    return result;
}

Paul's avatar
Paul committed
37
// Stores the literal `l` as a new instruction at the end of the program and
// returns a reference to that instruction.
instruction_ref program::add_literal(literal l)
{
    auto& instrs = impl->instructions;
    instrs.emplace_back(std::move(l));
    // The element just appended sits immediately before end().
    return std::prev(instrs.end());
}

Paul's avatar
Paul committed
43
// Appends a `builtin::param` instruction carrying `name`, with result shape `s`
// and no arguments, and returns a reference to it.
instruction_ref program::add_parameter(std::string name, shape s)
{
    auto& instrs = impl->instructions;
    instrs.push_back({builtin::param{std::move(name)}, s, {}});
    // The element just appended sits immediately before end().
    return std::prev(instrs.end());
}

Paul's avatar
Paul committed
49
// Returns true when `ins` refers to an instruction object stored in this
// program. Instructions are matched by address, not by value.
bool program::has_instruction(instruction_ref ins) const
{
    const auto& instrs = impl->instructions;
    // `ins` is dereferenced inside the predicate only, so it is never touched
    // when the program holds no instructions.
    const auto is_same_object = [&](const instruction& candidate) {
        return std::addressof(*ins) == std::addressof(candidate);
    };
    return std::any_of(instrs.begin(), instrs.end(), is_same_object);
}

Paul's avatar
Paul committed
57
// Evaluates the instructions in stored order and returns the value produced by
// the last one, wrapped in a literal.
//
// - "@literal" instructions yield the argument of their stored literal.
// - Instructions whose op name starts with "@param" are looked up in `params`
//   by the part of the op name after its first 7 characters.
// - Every other instruction calls `op.compute` on the already computed values
//   of its arguments.
//
// `params.at` and `results.at` throw std::out_of_range when a key is missing.
literal program::eval(std::unordered_map<std::string, argument> params) const
{
    // Computed value of each instruction seen so far, keyed by its address.
    std::unordered_map<const instruction*, argument> results;
    argument result;
    for(auto& ins : impl->instructions)
    {
        const auto& op_name = ins.op.name();
        if(op_name == "@literal")
        {
            result = ins.lit.get_argument();
        }
        else if(starts_with(op_name, "@param"))
        {
            // NOTE(review): 7 presumably is the length of a "@param:" prefix - confirm
            // against builtin::param.
            result = params.at(op_name.substr(7));
        }
        else
        {
            std::vector<argument> values;
            values.reserve(ins.arguments.size());
            for(instruction_ref arg : ins.arguments)
                values.push_back(results.at(std::addressof(*arg)));
            result = ins.op.compute(values);
        }
        results.emplace(std::addressof(ins), result);
    }
    return literal{result.get_shape(), result.data()};
}

Paul's avatar
Paul committed
85
86
87
88
89
void program::print() const
{
    std::unordered_map<const instruction*, std::string> names;
    int count = 0;

Paul's avatar
Paul committed
90
    for(auto& ins : impl->instructions)
Paul's avatar
Paul committed
91
92
93
94
95
96
97
98
99
100
101
102
103
    {
        std::string var_name = "@" + std::to_string(count);
        if(starts_with(ins.op.name(), "@param"))
        {
            var_name = ins.op.name().substr(7);
        }

        std::cout << var_name << " = ";

        std::cout << ins.op.name();

        if(ins.op.name() == "@literal")
        {
Paul's avatar
Paul committed
104
            if(ins.lit.get_shape().elements() > 10)
Paul's avatar
Paul committed
105
106
107
                std::cout << "{ ... }";
            else
                std::cout << "{" << ins.lit << "}";
Paul's avatar
Paul committed
108
109
110
111
112
        }

        if(!ins.arguments.empty())
        {
            char delim = '(';
Paul's avatar
Paul committed
113
            for(auto&& arg : ins.arguments)
Paul's avatar
Paul committed
114
            {
Paul's avatar
Paul committed
115
                assert(this->has_instruction(arg) && "Instruction not found");
Paul's avatar
Paul committed
116
                std::cout << delim << names.at(std::addressof(*arg));
Paul's avatar
Paul committed
117
118
119
120
121
122
123
124
125
126
                delim = ',';
            }
            std::cout << ")";
        }

        std::cout << " -> " << ins.result;

        std::cout << std::endl;

        names.emplace(std::addressof(ins), var_name);
Paul's avatar
Paul committed
127
        count++;
Paul's avatar
Paul committed
128
129
130
    }
}

Paul's avatar
Paul committed
131
} // namespace rtg