Commit b49d8e66 authored by Paul's avatar Paul
Browse files

Make reshape require standard shape input

parent ca69c190
...@@ -26,13 +26,6 @@ void eliminate_contiguous::apply(program& p) const ...@@ -26,13 +26,6 @@ void eliminate_contiguous::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// skip the reshape operator for now, since there is a bug
// for the transpose followed by a reshape
if(ins->name() == "reshape")
{
continue;
}
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
......
...@@ -29,7 +29,7 @@ struct reshape ...@@ -29,7 +29,7 @@ struct reshape
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1).standard();
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end()); std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
......
...@@ -29,6 +29,7 @@ struct squeeze ...@@ -29,6 +29,7 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard();
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
......
...@@ -29,6 +29,7 @@ struct unsqueeze ...@@ -29,6 +29,7 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard();
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
......
...@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins) ...@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
// clang-format off // clang-format off
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
"reshape", "reshape",
"contiguous" "contiguous",
"squeeze",
"unsqueeze"
}; };
// clang-format on // clang-format on
return contains(names, ins->name()); return contains(names, ins->name());
...@@ -94,13 +96,6 @@ void simplify_reshapes::apply(program& p) const ...@@ -94,13 +96,6 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(ins, t->inputs().front()); p.replace_instruction(ins, t->inputs().front());
} }
} }
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -50,7 +50,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -50,7 +50,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
constant_propagate{}, constant_propagate{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
//simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{}, dead_code_elimination{},
lowering{ctx}, lowering{ctx},
eliminate_concat{concat_gpu_optimization{}}, eliminate_concat{concat_gpu_optimization{}},
......
...@@ -1209,22 +1209,6 @@ struct test_contiguous : verify_program<test_contiguous> ...@@ -1209,22 +1209,6 @@ struct test_contiguous : verify_program<test_contiguous>
} }
}; };
struct test_eliminate_contiguous : verify_program<test_eliminate_contiguous>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto seq = p.add_parameter("seq", s);
std::vector<int64_t> perm{0, 2, 1, 3};
auto tran_seq = p.add_instruction(migraphx::op::transpose{perm}, seq);
std::vector<int64_t> out_shape{0, 0, -1};
p.add_instruction(migraphx::op::reshape{out_shape}, tran_seq);
return p;
}
};
struct test_transpose : verify_program<test_transpose> struct test_transpose : verify_program<test_transpose>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment