Commit 81eca52d authored by Paul's avatar Paul
Browse files

Add replace instruction

parent 16635962
......@@ -5,10 +5,13 @@
#include <rtg/shape.hpp>
#include <rtg/builtin.hpp>
#include <rtg/instruction_ref.hpp>
#include <rtg/erase.hpp>
#include <string>
namespace rtg {
shape compute_shape(operation op, std::vector<instruction_ref> args);
struct instruction
{
instruction() {}
......@@ -20,6 +23,60 @@ struct instruction
instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
void replace(operation o, shape r, std::vector<instruction_ref> args)
{
op = o;
replace(std::move(r));
replace(std::move(args));
}
void replace(shape r)
{
if(r != result)
{
result = r;
for(auto&& ins:output)
{
ins->replace(compute_shape(ins->op, ins->arguments));
}
}
}
void replace(std::vector<instruction_ref> args)
{
clear_arguments();
arguments = std::move(args);
}
void clear_arguments()
{
for(auto&& arg:arguments)
{
rtg::erase(arg->output, *this);
}
}
friend bool operator==(const instruction& i, instruction_ref ref)
{
return std::addressof(i) == std::addressof(*ref);
}
friend bool operator==(instruction_ref ref, const instruction& i)
{
return i == ref;
}
friend bool operator!=(const instruction& i, instruction_ref ref)
{
return !(i == ref);
}
friend bool operator!=(instruction_ref ref, const instruction& i)
{
return !(i == ref);
}
operation op;
shape result;
std::vector<instruction_ref> output;
......@@ -27,6 +84,22 @@ struct instruction
literal lit;
};
inline void backreference(instruction_ref ref)
{
for(auto&& arg : ref->arguments)
arg->output.push_back(ref);
}
// TODO: Move to a cpp file
// TODO: Use const ref for vector
inline shape compute_shape(operation op, std::vector<instruction_ref> args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->result; });
return op.compute_shape(shapes);
}
} // namespace rtg
#endif
......@@ -36,6 +36,14 @@ struct program
instruction_ref
insert_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
{
return replace_instruction(ins, op, {args...});
}
instruction_ref
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
......
......@@ -28,17 +28,28 @@ program::insert_instruction(instruction_ref ins, operation op, std::vector<instr
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->result; });
shape r = op.compute_shape(shapes);
// TODO: Use move
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, args});
backreference(result);
assert(result->arguments == args);
for(auto&& arg : args)
arg->output.push_back(result);
return result;
}
instruction_ref
program::replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
shape r = compute_shape(op, args);
ins->replace(op, r, args);
backreference(ins);
return ins;
}
instruction_ref program::add_literal(literal l)
{
impl->instructions.emplace_back(std::move(l));
......
......@@ -33,6 +33,35 @@ struct sum_op
}
};
struct minus_op
{
std::string name() const { return "minus"; }
rtg::argument compute(std::vector<rtg::argument> args) const
{
rtg::argument result;
if(args.size() != 2)
RTG_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape())
RTG_THROW("Wrong args");
if(args[0].get_shape().lens().size() != 1)
RTG_THROW("Wrong args");
if(args[0].get_shape().lens().front() != 1)
RTG_THROW("Wrong args");
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = rtg::literal{x - y}.get_argument(); });
});
return result;
}
rtg::shape compute_shape(std::vector<rtg::shape> inputs) const
{
if(inputs.size() != 2)
RTG_THROW("Wrong inputs");
return inputs.front();
}
};
void literal_test()
{
rtg::program p;
......@@ -59,8 +88,23 @@ void param_test()
EXPECT(result != rtg::literal{4});
}
void replace_test()
{
rtg::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.replace_instruction(sum, minus_op{}, two, one);
auto result = p.eval({});
EXPECT(result == rtg::literal{1});
EXPECT(result != rtg::literal{3});
}
int main()
{
literal_test();
param_test();
replace_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment