Commit d2a38cd4 authored by Paul's avatar Paul
Browse files

Add simplify reshapes pass

parent fc8ff61f
......@@ -5,6 +5,7 @@ add_library(migraph
generate.cpp
program.cpp
shape.cpp
simplify_reshapes.cpp
)
rocm_clang_tidy_check(migraph)
target_include_directories(migraph PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
......
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct simplify_reshapes
{
std::string name() const { return "simplify_reshapes"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
......@@ -65,7 +65,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
// TODO: Should it be an error if the output is empty?
if(ins->output.empty())
{
remove_instruction(ins);
return rep;
}
for(auto&& out : ins->output)
......@@ -80,8 +79,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
// Replacement should not be dead code unless its the last instruction
assert(!rep->output.empty() or rep == std::prev(end()));
assert(ins->valid(begin()));
if(ins->output.empty())
remove_instruction(ins);
assert(rep->valid(begin()));
return rep;
}
......
#include <migraph/simplify_reshapes.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <unordered_set>
namespace migraph {
bool is_reshaper(const std::string& name)
{
static const std::unordered_set<std::string> names = {
"reshape",
"transpose",
// "broadcast",
"contiguous"
};
return contains(names, name);
}
void simplify_reshapes::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(not is_reshaper(ins->op.name()))
continue;
if(ins->output.size() != 1)
continue;
if(is_reshaper(ins->output.front()->op.name()))
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()->op.name()))
{
assert(!reshapes.back()->arguments.empty());
assert(p.has_instruction(reshapes.back()->arguments.front()));
reshapes.push_back(reshapes.back()->arguments.front());
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start:iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->result == (*start)->result and i != (*start);
});
if (last != reshapes.rend()) {
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second) {
p.replace_instruction(r.first, r.second);
}
}
}
} // namespace migraph
......@@ -4,6 +4,8 @@
#include <migraph/gpu/context.hpp>
#include <migraph/check_context.hpp>
#include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/simplify_reshapes.hpp>
namespace migraph {
namespace gpu {
......@@ -14,8 +16,10 @@ std::vector<pass> target::get_passes(migraph::context&) const
return
{
auto_contiguous{},
simplify_reshapes{},
lowering{},
write_literals{},
dead_code_elimination{},
check_context<context>{}
};
// clang-format on
......
#include <migraph/simplify_reshapes.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct simplify_reshapes_target
{
std::string name() const { return "simplify_reshapes"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::simplify_reshapes{}, migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
migraph::literal get_2x2()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2x2_transposed()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}, {1, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type, {2}}, {1, 2}}; }
migraph::literal get_2_broadcasted()
{
return migraph::literal{{migraph::shape::float_type, {2, 1}, {1, 0}}, {1, 2}};
}
void double_contig()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraph::contiguous{}, t1);
auto c2 = p.add_instruction(migraph::contiguous{}, c1);
p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == get_2x2());
}
void double_transpose()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l);
auto t2 = p.add_instruction(migraph::transpose{{1, 0}}, t1);
p.add_instruction(pass_op{}, t2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == get_2x2());
}
void double_transpose_contig()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraph::contiguous{}, t1);
auto t2 = p.add_instruction(migraph::transpose{{1, 0}}, c1);
auto c2 = p.add_instruction(migraph::contiguous{}, t2);
p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == get_2x2());
}
void single_transpose()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t1);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 3);
auto result = p.eval({});
EXPECT(result != get_2x2());
}
void double_transpose_sin_pass()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::transpose{{1, 0}}, l);
p.add_instruction(migraph::transpose{{1, 0}}, t1);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
// std::cout << p << std::endl;
// TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1);
auto result = p.eval({});
EXPECT(result == get_2x2());
}
void single_transpose_sin_pass()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
p.add_instruction(migraph::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result != get_2x2());
}
int main()
{
double_contig();
double_transpose();
double_transpose_contig();
single_transpose();
double_transpose_sin_pass();
single_transpose_sin_pass();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment