Commit a96fae91 authored by Paul's avatar Paul
Browse files

Add a pass to eliminate contiguous operators

parent f0604d78
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
add_library(migraph add_library(migraph
auto_contiguous.cpp auto_contiguous.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
eliminate_contiguous.cpp
generate.cpp generate.cpp
program.cpp program.cpp
shape.cpp shape.cpp
......
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
namespace migraph {
bool try_compute_shape(operation op, std::vector<instruction_ref> args)
{
try
{
compute_shape(op, args);
}
catch(...)
{
return false;
}
return true;
}
void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// Make a copy so we can modify it while we iterate
auto args = ins->arguments;
for(auto arg : ins->arguments)
{
// TODO: Pass in names for the operator in the constructor instead
// of using ends_with
if(ends_with(arg->op.name(), "contiguous"))
{
auto new_args = args;
auto prev = arg->arguments.front();
replace(new_args, arg, prev);
if(try_compute_shape(ins->op, new_args))
{
replace_argument(ins, arg, prev);
}
}
}
}
}
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct eliminate_contiguous
{
std::string name() const { return "eliminate_contiguous"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
...@@ -24,6 +24,7 @@ struct instruction ...@@ -24,6 +24,7 @@ struct instruction
instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {} instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
// internal
void replace(operation o, shape r, std::vector<instruction_ref> args) void replace(operation o, shape r, std::vector<instruction_ref> args)
{ {
op = o; op = o;
...@@ -46,12 +47,14 @@ struct instruction ...@@ -46,12 +47,14 @@ struct instruction
void recompute_shape() { replace(compute_shape(op, arguments)); } void recompute_shape() { replace(compute_shape(op, arguments)); }
// internal
void replace(std::vector<instruction_ref> args) void replace(std::vector<instruction_ref> args)
{ {
clear_arguments(); clear_arguments();
arguments = std::move(args); arguments = std::move(args);
} }
// internal
void replace_argument(instruction_ref old, instruction_ref new_ins) void replace_argument(instruction_ref old, instruction_ref new_ins)
{ {
std::replace(arguments.begin(), arguments.end(), old, new_ins); std::replace(arguments.begin(), arguments.end(), old, new_ins);
......
...@@ -17,6 +17,12 @@ void copy(Range&& r, Iterator it) ...@@ -17,6 +17,12 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it); std::copy(r.begin(), r.end(), it);
} }
template <class Range, class T>
void replace(Range&& r, const T& old, const T& new_x)
{
std::replace(r.begin(), r.end(), old, new_x);
}
template <class Iterator> template <class Iterator>
struct iterator_range struct iterator_range
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraph/auto_contiguous.hpp> #include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraph/dead_code_elimination.hpp>
#include <migraph/simplify_reshapes.hpp> #include <migraph/simplify_reshapes.hpp>
#include <migraph/eliminate_contiguous.hpp>
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
...@@ -19,6 +20,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -19,6 +20,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
auto_contiguous{}, auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
lowering{ctx}, lowering{ctx},
eliminate_contiguous{},
write_literals{}, write_literals{},
check_context<context>{}, check_context<context>{},
dead_code_elimination{} dead_code_elimination{}
......
...@@ -13,23 +13,6 @@ struct contiguous_target ...@@ -13,23 +13,6 @@ struct contiguous_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
migraph::literal get_2x2()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2x2_transposed()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}, {1, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type, {2}}, {1, 2}}; }
migraph::literal get_2_broadcasted()
{
return migraph::literal{{migraph::shape::float_type, {2, 1}, {1, 0}}, {1, 2}};
}
void literal_broadcast() void literal_broadcast()
{ {
migraph::program p; migraph::program p;
......
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct eliminate_contiguous_target
{
std::string name() const { return "eliminate_contiguous"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::eliminate_contiguous{}, migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
void standard_op()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraph::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraph::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
void non_standard_op()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraph::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraph::contiguous{}, t);
p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
}
int main()
{
standard_op();
non_standard_op();
}
...@@ -81,6 +81,30 @@ struct pass_op ...@@ -81,6 +81,30 @@ struct pass_op
} }
}; };
struct pass_standard_op
{
std::string name() const { return "pass"; }
migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
for(auto&& input:inputs)
{
if(not input.standard())
throw std::runtime_error("Not standard shape");
}
if(inputs.empty())
return {};
return inputs.front();
}
};
struct nop struct nop
{ {
std::string name() const { return "nop"; } std::string name() const { return "nop"; }
...@@ -92,3 +116,20 @@ struct nop ...@@ -92,3 +116,20 @@ struct nop
migraph::shape compute_shape(std::vector<migraph::shape>) const { return {}; } migraph::shape compute_shape(std::vector<migraph::shape>) const { return {}; }
}; };
migraph::literal get_2x2()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2x2_transposed()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}, {1, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type, {2}}, {1, 2}}; }
migraph::literal get_2_broadcasted()
{
return migraph::literal{{migraph::shape::float_type, {2, 1}, {1, 0}}, {1, 2}};
}
...@@ -14,23 +14,6 @@ struct simplify_reshapes_target ...@@ -14,23 +14,6 @@ struct simplify_reshapes_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
migraph::literal get_2x2()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2x2_transposed()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}, {1, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type, {2}}, {1, 2}}; }
migraph::literal get_2_broadcasted()
{
return migraph::literal{{migraph::shape::float_type, {2, 1}, {1, 0}}, {1, 2}};
}
void double_contig() void double_contig()
{ {
migraph::program p; migraph::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment