Commit 80ffc159 authored by Paul's avatar Paul
Browse files

Add replace_instructions

parent d233807e
...@@ -13,8 +13,10 @@ void auto_contiguous::apply(program& p) const ...@@ -13,8 +13,10 @@ void auto_contiguous::apply(program& p) const
shape s = ins->result; shape s = ins->result;
if(not s.standard()) if(not s.standard())
{ {
auto prev = p.insert_instruction(ins, ins->op, ins->arguments); auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
p.replace_instruction(ins, contiguous{}, prev); p.replace_instructions(ins, ins, std::next(c));
// auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
// p.replace_instruction(ins, contiguous{}, prev);
} }
} }
} }
......
...@@ -39,17 +39,28 @@ struct instruction ...@@ -39,17 +39,28 @@ struct instruction
for(auto&& ins : output) for(auto&& ins : output)
{ {
assert(ins->op.name().front() != '@'); assert(ins->op.name().front() != '@');
ins->replace(compute_shape(ins->op, ins->arguments)); ins->recompute_shape();
} }
} }
} }
void recompute_shape()
{
replace(compute_shape(op, arguments));
}
void replace(std::vector<instruction_ref> args) void replace(std::vector<instruction_ref> args)
{ {
clear_arguments(); clear_arguments();
arguments = std::move(args); arguments = std::move(args);
} }
void replace_argument(instruction_ref old, instruction_ref new_ins)
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
recompute_shape();
}
void clear_arguments() void clear_arguments()
{ {
for(auto&& arg : arguments) for(auto&& arg : arguments)
......
...@@ -52,6 +52,8 @@ struct program ...@@ -52,6 +52,8 @@ struct program
instruction_ref instruction_ref
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args); replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
instruction_ref replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last);
instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last); instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
......
...@@ -17,6 +17,29 @@ void copy(Range&& r, Iterator it) ...@@ -17,6 +17,29 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it); std::copy(r.begin(), r.end(), it);
} }
template<class Iterator>
struct iterator_range
{
Iterator start;
Iterator last;
Iterator begin() const
{
return start;
}
Iterator end() const
{
return last;
}
};
template<class Iterator>
iterator_range<Iterator> range(Iterator start, Iterator last)
{
return {start, last};
}
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -55,6 +55,23 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst ...@@ -55,6 +55,23 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
return ins; return ins;
} }
instruction_ref program::replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last)
{
auto rep = std::prev(last);
for(auto&& out:ins->output)
{
if(std::find(start, last, out) == last)
{
out->replace_argument(ins, rep);
backreference(out);
}
}
if(ins->output.empty())
return remove_instruction(ins);
return ins;
}
instruction_ref program::remove_instruction(instruction_ref ins) instruction_ref program::remove_instruction(instruction_ref ins)
{ {
assert(has_instruction(ins)); assert(has_instruction(ins));
......
...@@ -14,35 +14,41 @@ ...@@ -14,35 +14,41 @@
#include "test.hpp" #include "test.hpp"
#include "verify.hpp" #include "verify.hpp"
struct auto_eval auto& handlers()
{ {
migraph::program* p; static std::array<std::function<void()>, 2> x = {};
migraph::program::parameter_map* m; return x;
migraph::argument result; }
auto_eval(migraph::program& pp, migraph::program::parameter_map& pm) : p(&pp), m(&pm) {}
migraph::argument operator()() const { return p->eval(*m); } struct auto_print
{
migraph::program& p;
int index;
auto_print(migraph::program& pp, int i) : p(pp), index(i)
{
handlers()[index] = [this]{ std::cout << p << std::endl; };
}
~auto_eval() ~auto_print()
{ {
if(std::uncaught_exception()) handlers()[index] = []{};
std::cout << *p << std::endl;
} }
}; };
template <class V> template <class V>
migraph::argument run_cpu() migraph::argument run_cpu()
{ {
V v; V v;
auto p = v.create_program(); auto p = v.create_program();
auto_print pp{p, 0};
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
migraph::program::parameter_map m; migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::generate_argument(x.second); m[x.first] = migraph::generate_argument(x.second);
} }
return auto_eval(p, m)(); return p.eval(m);
} }
template <class V> template <class V>
...@@ -50,6 +56,7 @@ migraph::argument run_gpu() ...@@ -50,6 +56,7 @@ migraph::argument run_gpu()
{ {
V v; V v;
auto p = v.create_program(); auto p = v.create_program();
auto_print pp{p, 1};
p.compile(migraph::gpu::target{}); p.compile(migraph::gpu::target{});
migraph::program::parameter_map m; migraph::program::parameter_map m;
...@@ -58,12 +65,23 @@ migraph::argument run_gpu() ...@@ -58,12 +65,23 @@ migraph::argument run_gpu()
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second)); m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
} }
return migraph::gpu::from_gpu(auto_eval(p, m)()); return migraph::gpu::from_gpu(p.eval(m));
} }
template <class V> template <class V>
void verify_program() void verify_program()
{ {
std::set_terminate(+[] {
try
{
std::rethrow_exception(std::current_exception());
}
catch(const std::exception& e)
{
std::cout << "what(): " << e.what() << std::endl;
}
for(auto&& handle:handlers()) handle();
});
auto cpu_arg = run_cpu<V>(); auto cpu_arg = run_cpu<V>();
auto gpu_arg = run_gpu<V>(); auto gpu_arg = run_gpu<V>();
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) { visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment