"tools/vscode:/vscode.git/clone" did not exist on "21a2bb0b0b7c81da725dfa04b940c4abdc10f4b1"
Commit fa485ae6 authored by Paul's avatar Paul
Browse files

Use lazy match operators so it will still short-circuit

parent 6d56671b
......@@ -7,6 +7,14 @@
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,6 +25,7 @@ struct loader
std::string file_type;
bool is_nhwc = true;
unsigned trim = 0;
bool optimize = false;
void parse(argument_parser& ap)
{
......@@ -26,6 +35,7 @@ struct loader
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true));
}
program load()
......@@ -48,6 +58,19 @@ struct loader
auto last = std::prev(p.end(), trim);
p.remove_instructions(last, p.end());
}
if (optimize)
migraphx::run_passes(p, {
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{},
migraphx::simplify_algebra{},
migraphx::dead_code_elimination{},
migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::propagate_constant{},
migraphx::dead_code_elimination{},
migraphx::eliminate_pad{},
migraphx::dead_code_elimination{},
});
return p;
}
};
......
......@@ -190,6 +190,12 @@ auto pop_back_args(Ts&&... xs)
};
}
template<class T>
auto always(T x)
{
return [=](auto&&...) { return x; };
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -240,6 +240,24 @@ void find_matches(program& p, Ms&&... ms)
}
}
struct lazy_and
{
template<class F, class G>
bool operator()(F f, G g) const
{
return f() and g();
}
};
struct lazy_or
{
template<class F, class G>
bool operator()(F f, G g) const
{
return f() or g();
}
};
template <class Op, bool Start, bool Matches>
struct folder
{
......@@ -248,8 +266,11 @@ struct folder
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
Op op;
auto matched = [&](auto m) {
return [&]{ return ctx.matched(m, ins); };
};
bool matches = fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) != ctx.not_found());
return op(always(x), matched(y));
})(Start, ms...);
if(matches == Matches)
return ins;
......@@ -265,9 +286,15 @@ struct folder
Op op;
bool matches = Start;
select(start, [&](auto ins) {
matches = op(matches, fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) != ctx.not_found());
})(Start, ms...));
auto matched = [&](auto m) {
return [&]{ return ctx.matched(m, ins); };
};
auto fold_match = [&] {
return fold([&](auto x, auto y) {
return op(always(x), matched(y));
})(Start, ms...);
};
matches = op(always(matches), fold_match);
});
if(matches == Matches)
return start;
......@@ -277,9 +304,9 @@ struct folder
}
};
const constexpr auto all_of = folder<std::logical_and<bool>, true, true>{};
const constexpr auto any_of = folder<std::logical_or<bool>, false, true>{};
const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
const constexpr auto all_of = folder<lazy_and, true, true>{};
const constexpr auto any_of = folder<lazy_or, false, true>{};
const constexpr auto none_of = folder<lazy_or, false, false>{};
inline auto inputs()
{
......
......@@ -162,7 +162,6 @@ struct find_transpose
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
return;
p.debug_print();
if(is_no_transpose(dims))
{
p.replace_instruction(ins, t->inputs().front());
......@@ -219,8 +218,9 @@ void simplify_reshapes::apply(program& p) const
ins,
find_nop_reshapes{},
find_reshaper{},
find_transpose{},
find_concat_transpose{});
find_transpose{}
// find_concat_transpose{}
);
}
}
......
......@@ -236,6 +236,8 @@ struct find_triadd
auto input_ins = r.instructions["input"];
auto ins = r.result;
auto args = add_ins->inputs();
assert(add_ins != input_ins);
auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
return;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment