"...resnet50_tensorflow.git" did not exist on "dc4b15e676b86c1627d79c9c0b4bf3793eb8e646"
Commit 57ee17f7 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

handle review comments from Umang

-int to size_t for trace
-move env arg to top of simplify_algebra.cpp
-handle overload beter for find_matches
parent 4acdeab6
...@@ -372,9 +372,9 @@ match::matcher_result find_match(module& modl, M&& m) ...@@ -372,9 +372,9 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
/// Find matches for an instruction in the module /// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) void find_matches(size_t trace_mod, Mod& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const const
...@@ -389,12 +389,12 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) ...@@ -389,12 +389,12 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
if(trace > 1) if(trace > 1 or trace_mod > 1)
std::cout << "Match: " << get_type_name(m) << std::endl; std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher()); auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end()) if(r.result == get_module(mod).end())
return; return;
if(trace > 0) if(trace > 0 or trace_mod > 0)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins); get_module(mod).debug_print(ins);
...@@ -424,59 +424,13 @@ void find_matches(Mod& mod, Ms&&... ms) ...@@ -424,59 +424,13 @@ void find_matches(Mod& mod, Ms&&... ms)
{ {
for(auto ins : iterator_for(get_module(mod))) for(auto ins : iterator_for(get_module(mod)))
{ {
find_matches(mod, ins, ms...); find_matches(0, mod, ins, ms...);
} }
} }
/// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms>
void find_matches(int trace_mod, Mod& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
bool match = false;
each_args(
[&](auto&& m) {
if(match)
return;
if(trace > 1 or trace_mod > 1)
std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return;
if(trace > 0 or trace_mod > 0)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins);
}
// If its already invalid dont validate it again
bool invalidated = validate and get_module(mod).validate() != get_module(mod).end();
m.apply(mod, r);
if(validate and not invalidated)
{
auto invalid = get_module(mod).validate();
if(invalid != get_module(mod).end())
{
std::cout << "Invalid program from match: " << get_type_name(m) << std::endl;
std::cout << "Invalid instructions: " << std::endl;
get_module(mod).debug_print(invalid->inputs());
get_module(mod).debug_print(invalid);
}
}
match = true;
},
ms...);
}
/// Find matches in a module /// Find matches in a module
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(int trace_mod, Mod& mod, Ms&&... ms) void find_matches(size_t trace_mod, Mod& mod, Ms&&... ms)
{ {
for(auto ins : iterator_for(get_module(mod))) for(auto ins : iterator_for(get_module(mod)))
{ {
......
...@@ -39,6 +39,8 @@ ...@@ -39,6 +39,8 @@
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <unordered_set> #include <unordered_set>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -1483,11 +1485,9 @@ struct find_split_transpose ...@@ -1483,11 +1485,9 @@ struct find_split_transpose
} }
}; };
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
void simplify_algebra::apply(module& m) const void simplify_algebra::apply(module& m) const
{ {
int trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{}); size_t trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment