Commit 8639a349 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

add auto contiguous

parent 69715eab
...@@ -30,8 +30,11 @@ ...@@ -30,8 +30,11 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <stdexcept>
#include <system_error>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -101,8 +104,8 @@ void transform_convolutions(module& m, bool skip_elim_contiguous) ...@@ -101,8 +104,8 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
// auto c = conv; // auto c = conv;
// auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}), // auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}),
// conv); m.debug_print(); if(not skip_elim_contiguous) // conv); m.debug_print(); if(not skip_elim_contiguous)
// c = m.insert_instruction(ins, make_op("contiguous"), conv); auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, conv); m.replace_instruction(ins, c);
} }
} }
...@@ -137,13 +140,18 @@ void layout_nhwc::apply(module_pass_manager& mpm) const ...@@ -137,13 +140,18 @@ void layout_nhwc::apply(module_pass_manager& mpm) const
{ {
// std::unordered_set<instruction_ref> output_layouts = // std::unordered_set<instruction_ref> output_layouts =
// preserve_output_layout(mpm.get_module()); // preserve_output_layout(mpm.get_module());
insert_contiguous(mpm.get_module()); // insert_contiguous(mpm.get_module());
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
// mpm.get_module().debug_print(); // mpm.get_module().debug_print();
transform_convolutions(mpm.get_module(), this->skip_elim_contiguous); transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
// std::cout << "after layout" << std::endl;
// mpm.get_module().debug_print();
// if(not this->skip_elim_contiguous) // if(not this->skip_elim_contiguous)
// mpm.run_pass(eliminate_contiguous{"contiguous"}); mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(auto_contiguous{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
// remove_layout(mpm.get_module(), output_layouts); // remove_layout(mpm.get_module(), output_layouts);
// mpm.run_pass(dead_code_elimination{}); // mpm.run_pass(dead_code_elimination{});
......
...@@ -39,10 +39,10 @@ struct test_conv_pooling : verify_program<test_conv_pooling> ...@@ -39,10 +39,10 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto relu = mm->add_instruction(migraphx::make_op("relu"), conv);
auto pooling = mm->add_instruction( auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv); migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), relu);
mm->add_instruction(migraphx::make_op("relu"), pooling); mm->add_instruction(migraphx::make_op("relu"), pooling);
std::cout << p << std::endl;
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment