Commit 8639a349 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

add auto contiguous

parent 69715eab
......@@ -30,8 +30,11 @@
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <stdexcept>
#include <system_error>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -101,8 +104,8 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
// auto c = conv;
// auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}),
// conv); m.debug_print(); if(not skip_elim_contiguous)
// c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, conv);
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c);
}
}
......@@ -137,13 +140,18 @@ void layout_nhwc::apply(module_pass_manager& mpm) const
{
// std::unordered_set<instruction_ref> output_layouts =
// preserve_output_layout(mpm.get_module());
insert_contiguous(mpm.get_module());
// insert_contiguous(mpm.get_module());
mpm.run_pass(dead_code_elimination{});
// mpm.get_module().debug_print();
transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
mpm.run_pass(dead_code_elimination{});
// std::cout << "after layout" << std::endl;
// mpm.get_module().debug_print();
// if(not this->skip_elim_contiguous)
// mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(auto_contiguous{});
mpm.run_pass(dead_code_elimination{});
// remove_layout(mpm.get_module(), output_layouts);
// mpm.run_pass(dead_code_elimination{});
......
......@@ -39,10 +39,10 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto relu = mm->add_instruction(migraphx::make_op("relu"), conv);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), relu);
mm->add_instruction(migraphx::make_op("relu"), pooling);
std::cout << p << std::endl;
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment