Commit 4ba12b4f authored by Paul's avatar Paul
Browse files

Add insert instruction

parent d3886685
...@@ -26,6 +26,14 @@ struct program ...@@ -26,6 +26,14 @@ struct program
return add_instruction(op, {args...}); return add_instruction(op, {args...});
} }
instruction_ref add_instruction(operation op, std::vector<instruction_ref> args); instruction_ref add_instruction(operation op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
{
return insert_instruction(ins, op, {args...});
}
instruction_ref insert_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
template <class... Ts> template <class... Ts>
instruction_ref add_literal(Ts&&... xs) instruction_ref add_literal(Ts&&... xs)
{ {
......
...@@ -19,17 +19,20 @@ program& program::operator=(program&&) noexcept = default; ...@@ -19,17 +19,20 @@ program& program::operator=(program&&) noexcept = default;
program::~program() noexcept = default; program::~program() noexcept = default;
instruction_ref program::add_instruction(operation op, std::vector<instruction_ref> args) instruction_ref program::add_instruction(operation op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), std::move(op), std::move(args));
}
instruction_ref program::insert_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args)
{ {
assert(std::all_of( assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) && args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction"); "Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size()); std::vector<shape> shapes(args.size());
std::transform( std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref ins) { return ins->result; }); args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->result; });
shape r = op.compute_shape(shapes); shape r = op.compute_shape(shapes);
impl->instructions.push_back({op, r, args}); auto result = impl->instructions.insert(ins, {op, r, args});
assert(impl->instructions.back().arguments == args); assert(result->arguments == args);
auto result = std::prev(impl->instructions.end());
for(auto&& arg : args) for(auto&& arg : args)
arg->output.push_back(result); arg->output.push_back(result);
return result; return result;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment