"src/targets/gpu/device/greater.cpp" did not exist on "0b217041e03162c32c41873cc024401d3ef1caba"
Unverified Commit 2c60e428 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #263 from ROCmSoftwarePlatform/reshape

Re-enable simplify_reshapes
parents cc8605e4 973b496b
......@@ -67,13 +67,6 @@ void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// skip the reshape operator for now, since there is a bug
// for the transpose followed by a reshape
if(ins->name() == "reshape")
{
continue;
}
// Make a copy so we can modify it while we iterate
auto args = ins->inputs();
for(auto arg : ins->inputs())
......
......@@ -103,6 +103,13 @@ struct check_shapes
return *this;
}
const check_shapes& standard_or_scalar() const
{
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
return *this;
}
const check_shapes& packed() const
{
if(!this->all_of([](const shape& s) { return s.packed(); }))
......
......@@ -29,7 +29,7 @@ struct reshape
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this}.has(1).standard();
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
......
......@@ -29,6 +29,7 @@ struct squeeze
std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
......
......@@ -29,6 +29,7 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard_or_scalar();
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
......
......@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
// clang-format off
static const std::unordered_set<std::string> names = {
"reshape",
"contiguous"
"contiguous",
"squeeze",
"unsqueeze"
};
// clang-format on
return contains(names, ins->name());
......@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
{
if(ins == end and ins->name() == "contiguous")
continue;
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
if(is_reshaper(ins))
......@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(ins, t->inputs().front());
}
}
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -51,7 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
propagate_constant{},
dead_code_elimination{},
auto_contiguous{},
//simplify_reshapes{},
simplify_reshapes{},
dead_code_elimination{},
lowering{ctx},
eliminate_concat{concat_gpu_optimization{}},
......
......@@ -393,7 +393,9 @@ struct tf_parser
int64_t out_channels = num_channels * multiplier;
new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1;
auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, weights);
// Make sure weights are contiguous before doing reshape
auto cweights = prog.add_instruction(op::contiguous{}, weights);
auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
return prog.add_instruction(op, {args[0], new_weights});
}
......
......@@ -1251,22 +1251,6 @@ struct test_contiguous : verify_program<test_contiguous>
}
};
struct test_eliminate_contiguous : verify_program<test_eliminate_contiguous>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto seq = p.add_parameter("seq", s);
std::vector<int64_t> perm{0, 2, 1, 3};
auto tran_seq = p.add_instruction(migraphx::op::transpose{perm}, seq);
std::vector<int64_t> out_shape{0, 0, -1};
p.add_instruction(migraphx::op::reshape{out_shape}, tran_seq);
return p;
}
};
struct test_transpose : verify_program<test_transpose>
{
migraphx::program create_program() const
......
......@@ -136,8 +136,9 @@ TEST_CASE(depthwiseconv_test)
op.group = 3;
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
auto l4 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l3);
p.add_instruction(op, l0, l4);
auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
p.add_instruction(op, l0, l5);
auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment