Commit 88f53aa9 authored by Paul's avatar Paul
Browse files

Formatting

parent 4506cb47
......@@ -409,7 +409,8 @@ struct tf_parser
new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape
auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
auto new_weights =
prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
return prog.add_instruction(op, {args[0], new_weights});
}
......
......@@ -14,8 +14,11 @@
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
auto prog = migraphx::parse_tf(name, is_nhwc);
if (is_nhwc)
migraphx::run_passes(prog, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}, migraphx::eliminate_identity{}});
if(is_nhwc)
migraphx::run_passes(prog,
{migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}});
return prog;
}
......@@ -189,7 +192,7 @@ TEST_CASE(mean_test)
migraphx::op::pooling op;
op.lengths = {16, 16};
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = optimize_tf("mean_test.pb", false);
......@@ -247,11 +250,11 @@ TEST_CASE(pack_test)
TEST_CASE(pack_test_nhwc)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment