Commit 88f53aa9 authored by Paul's avatar Paul
Browse files

Formatting

parent 4506cb47
...@@ -409,7 +409,8 @@ struct tf_parser ...@@ -409,7 +409,8 @@ struct tf_parser
new_weights_shape[0] = out_channels; new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1; new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape // Make sure weights are contiguous before doing reshape
auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights)); auto new_weights =
prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
return prog.add_instruction(op, {args[0], new_weights}); return prog.add_instruction(op, {args[0], new_weights});
} }
......
...@@ -14,8 +14,11 @@ ...@@ -14,8 +14,11 @@
migraphx::program optimize_tf(const std::string& name, bool is_nhwc) migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{ {
auto prog = migraphx::parse_tf(name, is_nhwc); auto prog = migraphx::parse_tf(name, is_nhwc);
if (is_nhwc) if(is_nhwc)
migraphx::run_passes(prog, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}, migraphx::eliminate_identity{}}); migraphx::run_passes(prog,
{migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}});
return prog; return prog;
} }
...@@ -189,7 +192,7 @@ TEST_CASE(mean_test) ...@@ -189,7 +192,7 @@ TEST_CASE(mean_test)
migraphx::op::pooling op; migraphx::op::pooling op;
op.lengths = {16, 16}; op.lengths = {16, 16};
p.add_instruction(op, l0); p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0); auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = optimize_tf("mean_test.pb", false); auto prog = optimize_tf("mean_test.pb", false);
...@@ -247,11 +250,11 @@ TEST_CASE(pack_test) ...@@ -247,11 +250,11 @@ TEST_CASE(pack_test)
TEST_CASE(pack_test_nhwc) TEST_CASE(pack_test_nhwc)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0); auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1); auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2); auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2}; std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args; std::vector<migraphx::instruction_ref> unsqueezed_args;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment