Commit 80ffc159 authored by Paul's avatar Paul
Browse files

Add replace_instructions

parent d233807e
......@@ -13,8 +13,10 @@ void auto_contiguous::apply(program& p) const
shape s = ins->result;
if(not s.standard())
{
auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
p.replace_instruction(ins, contiguous{}, prev);
auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
p.replace_instructions(ins, ins, std::next(c));
// auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
// p.replace_instruction(ins, contiguous{}, prev);
}
}
}
......
......@@ -39,17 +39,28 @@ struct instruction
for(auto&& ins : output)
{
assert(ins->op.name().front() != '@');
ins->replace(compute_shape(ins->op, ins->arguments));
ins->recompute_shape();
}
}
}
void recompute_shape()
{
replace(compute_shape(op, arguments));
}
void replace(std::vector<instruction_ref> args)
{
clear_arguments();
arguments = std::move(args);
}
void replace_argument(instruction_ref old, instruction_ref new_ins)
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
recompute_shape();
}
void clear_arguments()
{
for(auto&& arg : arguments)
......
......@@ -52,6 +52,8 @@ struct program
instruction_ref
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
instruction_ref replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
......
......@@ -17,6 +17,29 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it);
}
template<class Iterator>
struct iterator_range
{
Iterator start;
Iterator last;
Iterator begin() const
{
return start;
}
Iterator end() const
{
return last;
}
};
template<class Iterator>
iterator_range<Iterator> range(Iterator start, Iterator last)
{
return {start, last};
}
} // namespace migraph
#endif
......@@ -55,6 +55,23 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
return ins;
}
instruction_ref program::replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last)
{
auto rep = std::prev(last);
for(auto&& out:ins->output)
{
if(std::find(start, last, out) == last)
{
out->replace_argument(ins, rep);
backreference(out);
}
}
if(ins->output.empty())
return remove_instruction(ins);
return ins;
}
instruction_ref program::remove_instruction(instruction_ref ins)
{
assert(has_instruction(ins));
......
......@@ -14,35 +14,41 @@
#include "test.hpp"
#include "verify.hpp"
struct auto_eval
auto& handlers()
{
migraph::program* p;
migraph::program::parameter_map* m;
migraph::argument result;
auto_eval(migraph::program& pp, migraph::program::parameter_map& pm) : p(&pp), m(&pm) {}
static std::array<std::function<void()>, 2> x = {};
return x;
}
migraph::argument operator()() const { return p->eval(*m); }
struct auto_print
{
migraph::program& p;
int index;
auto_print(migraph::program& pp, int i) : p(pp), index(i)
{
handlers()[index] = [this]{ std::cout << p << std::endl; };
}
~auto_eval()
~auto_print()
{
if(std::uncaught_exception())
std::cout << *p << std::endl;
handlers()[index] = []{};
}
};
template <class V>
migraph::argument run_cpu()
{
V v;
auto p = v.create_program();
auto_print pp{p, 0};
p.compile(migraph::cpu::cpu_target{});
migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraph::generate_argument(x.second);
}
return auto_eval(p, m)();
return p.eval(m);
}
template <class V>
......@@ -50,6 +56,7 @@ migraph::argument run_gpu()
{
V v;
auto p = v.create_program();
auto_print pp{p, 1};
p.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
......@@ -58,12 +65,23 @@ migraph::argument run_gpu()
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
return migraph::gpu::from_gpu(auto_eval(p, m)());
return migraph::gpu::from_gpu(p.eval(m));
}
template <class V>
void verify_program()
{
std::set_terminate(+[] {
try
{
std::rethrow_exception(std::current_exception());
}
catch(const std::exception& e)
{
std::cout << "what(): " << e.what() << std::endl;
}
for(auto&& handle:handlers()) handle();
});
auto cpu_arg = run_cpu<V>();
auto gpu_arg = run_gpu<V>();
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment