"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "8fe2729e0aa884c88189d42bf4fd7400b564d105"
Unverified Commit a0fa3742 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Add trace for SIMPLIFY_ALGEBRA matches (#1838)

* Add trace for SIMPLIFY_ALGEBRA matches

* Fix format

* handle review comments from Umang

-int to size_t for trace
-move env arg to top of simplify_algebra.cpp
-handle overload beter for find_matches

* Rename trace_mod param to trace_pass

More representative naming for what this trace flag does
parent b8898d7e
...@@ -372,9 +372,9 @@ match::matcher_result find_match(module& modl, M&& m) ...@@ -372,9 +372,9 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
/// Find matches for an instruction in the module /// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const const
...@@ -389,12 +389,12 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) ...@@ -389,12 +389,12 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
if(trace > 1) if(trace > 1 or trace_pass > 1)
std::cout << "Match: " << get_type_name(m) << std::endl; std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher()); auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end()) if(r.result == get_module(mod).end())
return; return;
if(trace > 0) if(trace > 0 or trace_pass > 0)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins); get_module(mod).debug_print(ins);
...@@ -424,7 +424,17 @@ void find_matches(Mod& mod, Ms&&... ms) ...@@ -424,7 +424,17 @@ void find_matches(Mod& mod, Ms&&... ms)
{ {
for(auto ins : iterator_for(get_module(mod))) for(auto ins : iterator_for(get_module(mod)))
{ {
find_matches(mod, ins, ms...); find_matches(0, mod, ins, ms...);
}
}
/// Find matches in a pass
template <class Mod, class... Ms>
void find_matches(size_t trace_pass, Mod& mod, Ms&&... ms)
{
for(auto ins : iterator_for(get_module(mod)))
{
find_matches(trace_pass, mod, ins, ms...);
} }
} }
......
...@@ -39,6 +39,8 @@ ...@@ -39,6 +39,8 @@
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <unordered_set> #include <unordered_set>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -1485,10 +1487,13 @@ struct find_split_transpose ...@@ -1485,10 +1487,13 @@ struct find_split_transpose
void simplify_algebra::apply(module& m) const void simplify_algebra::apply(module& m) const
{ {
size_t trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
match::find_matches(m, match::find_matches(trace,
m,
find_inner_broadcast{}, find_inner_broadcast{},
find_dot_broadcast{}, find_dot_broadcast{},
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment