Commit 7cc3243c authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Fix exception thrown when compiling inceptionv4 (#367)

* Fix compiler crash in TF inceptionv4

* Formatting

* Remove else
parent 3962c2ad
...@@ -19,7 +19,9 @@ void rewrite_pooling::apply(program& prog) const ...@@ -19,7 +19,9 @@ void rewrite_pooling::apply(program& prog) const
continue; continue;
if(ins->inputs().empty()) if(ins->inputs().empty())
continue; continue;
auto&& s = ins->inputs().front()->get_shape(); auto&& s = ins->inputs().front()->get_shape();
if(not s.standard())
continue;
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode != "average") if(op.mode != "average")
continue; continue;
......
...@@ -177,8 +177,7 @@ struct find_concat_transpose ...@@ -177,8 +177,7 @@ struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::same_input_shapes(), return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
match::all_of[match::inputs()](match::transpose_shape()));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(program& p, const match::matcher_result& mr) const
...@@ -194,8 +193,6 @@ struct find_concat_transpose ...@@ -194,8 +193,6 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
if(i->name() == "transpose" and i->inputs().front()->get_shape().standard())
return i->inputs().front();
return p.insert_instruction(ins, op::transpose{permutation}, i); return p.insert_instruction(ins, op::transpose{permutation}, i);
}); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = p.insert_instruction(ins, op, inputs);
...@@ -207,20 +204,23 @@ struct find_concat_transpose ...@@ -207,20 +204,23 @@ struct find_concat_transpose
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(program& p) const
{ {
auto end = std::prev(p.end()); for(int i = 0; i < 2; i++)
for(auto ins : iterator_for(p))
{ {
if(ins == end and ins->name() == "contiguous") auto end = std::prev(p.end());
continue; for(auto ins : iterator_for(p))
// Skip possible dead instructions {
if(ins->outputs().empty() and ins != end) if(ins == end and ins->name() == "contiguous")
continue; continue;
match::find_matches(p, // Skip possible dead instructions
ins, if(ins->outputs().empty() and ins != end)
find_nop_reshapes{}, continue;
find_reshaper{}, match::find_matches(p,
find_transpose{}, ins,
find_concat_transpose{}); find_nop_reshapes{},
find_reshaper{},
find_transpose{},
find_concat_transpose{});
}
} }
} }
......
#include <migraphx/env.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -24,6 +25,8 @@ ...@@ -24,6 +25,8 @@
#pragma clang diagnostic ignored "-Wglobal-constructors" #pragma clang diagnostic ignored "-Wglobal-constructors"
#endif #endif
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_GPU_COMPILE)
// An improved async, that doesn't block // An improved async, that doesn't block
template <class Function> template <class Function>
std::future<typename std::result_of<Function()>::type> detach_async(Function&& f, std::future<typename std::result_of<Function()>::type> detach_async(Function&& f,
...@@ -82,7 +85,7 @@ auto get_hash(const T& x) ...@@ -82,7 +85,7 @@ auto get_hash(const T& x)
return std::hash<T>{}(x); return std::hash<T>{}(x);
} }
void compile_check(migraphx::program& p, const migraphx::target& t) void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false)
{ {
auto name = t.name(); auto name = t.name();
auto s = p.get_shape(); auto s = p.get_shape();
...@@ -93,6 +96,10 @@ void compile_check(migraphx::program& p, const migraphx::target& t) ...@@ -93,6 +96,10 @@ void compile_check(migraphx::program& p, const migraphx::target& t)
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape"); throw std::runtime_error("Compiling program with " + name + " alters its shape");
} }
if(show_trace)
{
std::cout << ss.str() << std::endl;
}
} }
template <class V> template <class V>
...@@ -116,7 +123,7 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -116,7 +123,7 @@ migraphx::argument run_gpu(migraphx::program& p)
V v; V v;
p = v.create_program(); p = v.create_program();
auto_print pp{p, 1}; auto_print pp{p, 1};
compile_check(p, migraphx::gpu::target{}); compile_check(p, migraphx::gpu::target{}, migraphx::enabled(MIGRAPHX_TRACE_GPU_COMPILE{}));
migraphx::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
...@@ -985,6 +992,24 @@ struct test_conv_pooling : verify_program<test_conv_pooling> ...@@ -985,6 +992,24 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
} }
}; };
struct test_concat_pooling : verify_program<test_concat_pooling>
{
migraphx::program create_program() const
{
migraphx::program p;
auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}});
auto transpose = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input);
auto concat = p.add_instruction(migraphx::op::concat{3}, transpose);
auto concat_t = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, concat);
auto pooling =
p.add_instruction(migraphx::op::pooling{"average", {0, 0}, {1, 1}, {8, 8}}, concat_t);
p.add_instruction(migraphx::op::relu{}, pooling);
return p;
}
};
struct test_global_avg_pooling : verify_program<test_global_avg_pooling> struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -306,4 +306,26 @@ TEST_CASE(concat_transpose2) ...@@ -306,4 +306,26 @@ TEST_CASE(concat_transpose2)
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
} }
TEST_CASE(concat_transpose3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment