Commit 02837a8a authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add trace for SIMPLIFY_ALGEBRA matches

parent 193f105d
...@@ -428,6 +428,62 @@ void find_matches(Mod& mod, Ms&&... ms) ...@@ -428,6 +428,62 @@ void find_matches(Mod& mod, Ms&&... ms)
} }
} }
/// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms>
void find_matches(int trace_mod, Mod& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
bool match = false;
each_args(
[&](auto&& m) {
if(match)
return;
if(trace > 1 or trace_mod > 1)
std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return;
if(trace > 0 or trace_mod > 0)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins);
}
// If its already invalid dont validate it again
bool invalidated = validate and get_module(mod).validate() != get_module(mod).end();
m.apply(mod, r);
if(validate and not invalidated)
{
auto invalid = get_module(mod).validate();
if(invalid != get_module(mod).end())
{
std::cout << "Invalid program from match: " << get_type_name(m) << std::endl;
std::cout << "Invalid instructions: " << std::endl;
get_module(mod).debug_print(invalid->inputs());
get_module(mod).debug_print(invalid);
}
}
match = true;
},
ms...);
}
/// Find matches in a module
template <class Mod, class... Ms>
void find_matches(int trace_mod, Mod& mod, Ms&&... ms)
{
for(auto ins : iterator_for(get_module(mod)))
{
find_matches(trace_mod, mod, ins, ms...);
}
}
template <class M, class F> template <class M, class F>
struct find_generic_match struct find_generic_match
{ {
......
...@@ -1483,12 +1483,18 @@ struct find_split_transpose ...@@ -1483,12 +1483,18 @@ struct find_split_transpose
} }
}; };
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
void simplify_algebra::apply(module& m) const void simplify_algebra::apply(module& m) const
{ {
int trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
match::find_matches(m, match::find_matches(trace,
m,
find_inner_broadcast{}, find_inner_broadcast{},
find_dot_broadcast{}, find_dot_broadcast{},
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment