Commit c808514d authored by Khalique's avatar Khalique
Browse files

Merge branch 'tf-transpose' of...

Merge branch 'tf-transpose' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into tf-transpose_ss
parents cb9f01ae 7c8e1979
......@@ -7,6 +7,14 @@
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,6 +25,7 @@ struct loader
std::string file_type;
bool is_nhwc = true;
unsigned trim = 0;
bool optimize = false;
// Registers the loader's command-line flags with the argument parser.
// NOTE(review): diff context is elided here (hunk marker below) — the full
// function likely registers additional options not visible in this view.
void parse(argument_parser& ap)
{
......@@ -26,6 +35,7 @@ struct loader
// --nhwc / --nchw select how tensorflow tensors are interpreted (nhwc default).
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
// --trim/-t: number of trailing instructions to remove after loading.
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
// --optimize: run the simplification pass pipeline right after the model is read.
ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true));
}
program load()
......@@ -48,6 +58,20 @@ struct loader
auto last = std::prev(p.end(), trim);
p.remove_instructions(last, p.end());
}
if(optimize)
migraphx::run_passes(p,
{
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{},
migraphx::simplify_algebra{},
migraphx::dead_code_elimination{},
migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::propagate_constant{},
migraphx::dead_code_elimination{},
migraphx::eliminate_pad{},
migraphx::dead_code_elimination{},
});
return p;
}
};
......
......@@ -190,6 +190,23 @@ auto pop_back_args(Ts&&... xs)
};
}
// Function object that ignores all of its arguments and returns a stored value.
template <class T>
struct always_f
{
    T x;
    // Callable with any argument list; every argument is discarded.
    template <class... Ts>
    constexpr T operator()(Ts&&...) const
    {
        return x;
    }
};

// Wraps a value so that always(v)(anything...) == v.
template <class T>
auto always(T x)
{
    // Move x into the closure instead of copying it; static_cast<T&&> is an
    // unconditional move (T is deduced by value) without requiring <utility>.
    return always_f<T>{static_cast<T&&>(x)};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -240,17 +240,40 @@ void find_matches(program& p, Ms&&... ms)
}
}
// Short-circuiting conjunction over two thunks: the right-hand thunk is
// evaluated only when the left-hand one yields true.
struct lazy_and
{
    template <class F, class G>
    bool operator()(F f, G g) const
    {
        if(not f())
            return false;
        return g();
    }
};
// Short-circuiting disjunction over two thunks: the right-hand thunk is
// evaluated only when the left-hand one yields false.
struct lazy_or
{
    template <class F, class G>
    bool operator()(F f, G g) const
    {
        if(f())
            return true;
        return g();
    }
};
template <class Op, bool Start, bool Matches>
struct folder
{
// Matches `ins` against each matcher in `ms`, folding the boolean results
// with the lazy operation Op (lazy_and / lazy_or) seeded with Start. Both
// operands reach Op as thunks so it can short-circuit and skip evaluating
// the remaining matchers once the outcome is decided.
template <class... Ms>
static bool fold_match(matcher_context& ctx, instruction_ref ins, Ms... ms)
{
    Op op;
    // Capture the matcher (and ins) by value: the returned thunk is invoked
    // after matched() has returned, and the lifetime of a function parameter
    // past that point is implementation-defined — a by-reference capture of
    // m could dangle when Op finally calls the thunk.
    auto matched = [&ctx, ins](auto m) {
        return [&ctx, ins, m] { return ctx.matched(m, ins); };
    };
    // always(x) lifts the already-computed accumulator into a thunk for Op.
    return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
}
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
Op op;
bool matches = fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) != ctx.not_found());
})(Start, ms...);
bool matches = folder::fold_match(ctx, ins, ms...);
if(matches == Matches)
return ins;
return ctx.not_found();
......@@ -265,9 +288,8 @@ struct folder
Op op;
bool matches = Start;
select(start, [&](auto ins) {
matches = op(matches, fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) != ctx.not_found());
})(Start, ms...));
auto fm = [&] { return folder::fold_match(ctx, ins, ms...); };
matches = op(always(matches), fm);
});
if(matches == Matches)
return start;
......@@ -277,9 +299,9 @@ struct folder
}
};
const constexpr auto all_of = folder<std::logical_and<bool>, true, true>{};
const constexpr auto any_of = folder<std::logical_or<bool>, false, true>{};
const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
// all_of: matches only when every sub-matcher matches (lazy_and stops at the
// first failure).
const constexpr auto all_of = folder<lazy_and, true, true>{};
// any_of: matches when at least one sub-matcher matches (lazy_or stops at the
// first success).
const constexpr auto any_of = folder<lazy_or, false, true>{};
// none_of: matches only when no sub-matcher matches (folder's Matches=false
// inverts the folded result).
const constexpr auto none_of = folder<lazy_or, false, false>{};
inline auto inputs()
{
......
......@@ -54,11 +54,6 @@ std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t
return result;
}
// NOTE(review): superseded variant removed by this commit — it applies the
// permutation to itself via reorder_dims(p, p), which presumably yields the
// composition p∘p rather than the inverse for a general permutation; the
// sort_permutation-based definition later in this file replaces it. Verify
// against reorder_dims' definition (not fully visible here).
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return reorder_dims(permutation, permutation);
}
bool is_no_transpose(const std::vector<int64_t>& dims)
{
if(dims.empty())
......@@ -78,6 +73,11 @@ std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return result;
}
// Returns the inverse of `permutation`: ordering the indices of `permutation`
// by its values ascending gives q with q[permutation[i]] == i.
// NOTE(review): assumes sort_permutation returns the index order that sorts
// its input — confirm against sort_permutation's definition.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
......@@ -162,7 +162,6 @@ struct find_transpose
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
return;
p.debug_print();
if(is_no_transpose(dims))
{
p.replace_instruction(ins, t->inputs().front());
......@@ -190,7 +189,7 @@ struct find_concat_transpose
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
auto ipermutation = invert_permutation(permutation);
op.axis = permutation[op.axis];
op.axis = ipermutation[op.axis];
std::vector<instruction_ref> inputs;
std::transform(
......@@ -200,7 +199,8 @@ struct find_concat_transpose
return p.insert_instruction(ins, op::transpose{permutation}, i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{permutation}, concat);
auto t = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
}
};
......
......@@ -232,10 +232,12 @@ struct find_triadd
void apply(program& p, match::matcher_result r) const
{
auto add_ins = r.instructions["add"];
auto input_ins = r.instructions["input"];
auto ins = r.result;
auto args = add_ins->inputs();
auto add_ins = r.instructions["add"];
auto input_ins = r.instructions["input"];
auto ins = r.result;
auto args = add_ins->inputs();
assert(add_ins != input_ins);
auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
return;
......
......@@ -148,6 +148,20 @@ TEST_CASE(match_arg7)
EXPECT(bool{r.result == sum});
}
// all_of over both argument matchers, combined with a shape matcher, should
// still locate the sum instruction.
TEST_CASE(match_arg8)
{
    migraphx::program p;
    auto lit1 = p.add_literal(1);
    auto lit2 = p.add_literal(2);
    auto add  = p.add_instruction(sum_op{}, lit1, lit2);
    p.add_instruction(pass_op{}, add);
    auto matcher = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
                                                    match::arg(1)(match::name("@literal"))),
                                      match::standard_shape());
    auto result = find_match(p, matcher);
    EXPECT(bool{result.result == add});
}
TEST_CASE(match_args1)
{
migraphx::program p;
......
......@@ -284,4 +284,26 @@ TEST_CASE(concat_transpose1)
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}
// A transpose-concat-transpose chain should collapse: the pass removes the
// transposes and rewrites the concat axis from 3 to 1.
TEST_CASE(concat_transpose2)
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto px   = p.add_parameter("x", s);
    auto py   = p.add_parameter("y", s);
    auto tx   = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, px);
    auto ty   = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, py);
    auto cat  = p.add_instruction(migraphx::op::concat{3}, tx, ty);
    auto tout = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, cat);
    p.add_instruction(pass_op{}, tout);
    auto shape_before = p.get_shape();
    auto count_before = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    // The program's output shape is unchanged while two instructions vanish.
    EXPECT(p.get_shape().lens() == shape_before.lens());
    EXPECT(std::distance(p.begin(), p.end()) == count_before - 2);
    auto it = std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
    EXPECT(bool{it != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(it->get_operator()).axis == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment