"mmdet/datasets/data_engine.py" did not exist on "108fc9e1e03ffa6737d1a31ef76fd13bb0cd3cc9"
Commit 3b64f602 authored by Paul's avatar Paul
Browse files

Fix tf test

parent ea2f0cf4
......@@ -266,7 +266,7 @@ struct folder
bool matches = Start;
select(start, [&](auto ins) {
matches = op(matches, fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) == ctx.not_found());
return op(x, y.match(ctx, ins) != ctx.not_found());
})(Start, ms...));
});
if(matches == Matches)
......@@ -310,7 +310,7 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
return ins->get_shape().transposed();
}
MIGRAPHX_PRED_MATCHER(same_shapes, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
{
if(ins->inputs().empty())
return false;
......@@ -413,6 +413,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
};
}
template<class M>
auto same_shape(M m)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto i = m.match(ctx, ins);
if (i != ctx.not_found() and i->get_shape() == ins->get_shape())
return ins;
return ctx.not_found();
});
}
template<class... Ms>
auto same_shape(Ms... ms)
{
return all_of(same_shape(ms)...);
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -122,6 +122,25 @@ struct find_reshaper
}
};
struct find_nop_reshapes
{
auto matcher() const
{
auto reshapes = reshaper_names();
reshapes.insert("transpose");
reshapes.insert("slice");
return match::name(reshapes)(
match::same_shape(match::arg(0))
);
}
void apply(program& p, match::matcher_result mr) const
{
auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front());
}
};
struct find_transpose
{
auto matcher() const
......@@ -145,6 +164,7 @@ struct find_transpose
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
return;
p.debug_print();
if(is_no_transpose(dims))
{
p.replace_instruction(ins, t->inputs().front());
......@@ -160,7 +180,7 @@ struct find_concat_transpose
{
auto matcher() const
{
return match::name("concat")(match::same_shapes(),
return match::name("concat")(match::same_input_shapes(),
match::all_of[match::inputs()](match::transpose_shape()));
}
......@@ -168,19 +188,21 @@ struct find_concat_transpose
{
auto ins = mr.result;
auto s = ins->inputs().front()->get_shape();
assert(s.transposed());
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
auto ipermutaion = invert_permutation(permutation);
op.axis = ipermutaion[op.axis];
auto ipermutation = invert_permutation(permutation);
op.axis = permutation[op.axis];
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
if (i->name() == "transpose" and i->inputs().front()->get_shape().standard())
return i->inputs().front();
return p.insert_instruction(ins, op::transpose{permutation}, i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
auto t = p.insert_instruction(ins, op::transpose{permutation}, concat);
p.replace_instruction(ins, t);
}
};
......@@ -195,7 +217,7 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
match::find_matches(p, ins, find_reshaper{}, find_transpose{}, find_concat_transpose{});
match::find_matches(p, ins, find_nop_reshapes{}, find_reshaper{}, find_transpose{}, find_concat_transpose{});
}
}
......
......@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_transposed)
TEST_CASE(test_shape_transposed1)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard());
......@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_transposed2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 1, 2}, {2, 2, 2, 2, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_broadcasted)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
......
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -213,4 +214,78 @@ TEST_CASE(transpose_partial3)
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}
TEST_CASE(nop_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(nop_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}
TEST_CASE(nop_transpose3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(concat_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat = std::find_if(p.begin(), p.end(), [](auto ins) {
return ins.name() == "concat";
});
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment