Commit 57ee17f7 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

handle review comments from Umang

-int to size_t for trace
-move env arg to top of simplify_algebra.cpp
-handle overload beter for find_matches
parent 4acdeab6
......@@ -372,9 +372,9 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
/// Find matches for an instruction in the module
/// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms>
void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
void find_matches(size_t trace_mod, Mod& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
......@@ -389,12 +389,12 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
[&](auto&& m) {
if(match)
return;
if(trace > 1)
if(trace > 1 or trace_mod > 1)
std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return;
if(trace > 0)
if(trace > 0 or trace_mod > 0)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins);
......@@ -424,59 +424,13 @@ void find_matches(Mod& mod, Ms&&... ms)
{
for(auto ins : iterator_for(get_module(mod)))
{
find_matches(mod, ins, ms...);
}
}
/// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms>
void find_matches(int trace_mod, Mod& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
bool match = false;
each_args(
[&](auto&& m) {
if(match)
return;
if(trace > 1 or trace_mod > 1)
std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return;
if(trace > 0 or trace_mod > 0)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins);
}
// If its already invalid dont validate it again
bool invalidated = validate and get_module(mod).validate() != get_module(mod).end();
m.apply(mod, r);
if(validate and not invalidated)
{
auto invalid = get_module(mod).validate();
if(invalid != get_module(mod).end())
{
std::cout << "Invalid program from match: " << get_type_name(m) << std::endl;
std::cout << "Invalid instructions: " << std::endl;
get_module(mod).debug_print(invalid->inputs());
get_module(mod).debug_print(invalid);
}
find_matches(0, mod, ins, ms...);
}
match = true;
},
ms...);
}
/// Find matches in a module
template <class Mod, class... Ms>
void find_matches(int trace_mod, Mod& mod, Ms&&... ms)
void find_matches(size_t trace_mod, Mod& mod, Ms&&... ms)
{
for(auto ins : iterator_for(get_module(mod)))
{
......
......@@ -39,6 +39,8 @@
#include <migraphx/algorithm.hpp>
#include <unordered_set>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -1483,11 +1485,9 @@ struct find_split_transpose
}
};
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
void simplify_algebra::apply(module& m) const
{
int trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
size_t trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
// Run simplifications multiple times
for(int i = 0; i < 8; i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment