Commit 23124b09 authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Print graph from the driver (#389)

* Print graph from the driver

* Formatting
parent 50e6d5eb
...@@ -7,16 +7,19 @@ ...@@ -7,16 +7,19 @@
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <fstream>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,7 +40,7 @@ struct loader ...@@ -37,7 +40,7 @@ struct loader
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true)); ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false)); ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end")); ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true)); ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
} }
program load() program load()
...@@ -63,6 +66,7 @@ struct loader ...@@ -63,6 +66,7 @@ struct loader
if(optimize) if(optimize)
migraphx::run_passes(p, migraphx::run_passes(p,
{ {
migraphx::rewrite_batchnorm{},
migraphx::eliminate_identity{}, migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
migraphx::simplify_algebra{}, migraphx::simplify_algebra{},
...@@ -126,12 +130,36 @@ struct compiler ...@@ -126,12 +130,36 @@ struct compiler
struct read : command<read> struct read : command<read>
{ {
loader l; loader l;
void parse(argument_parser& ap) { l.parse(ap); } bool graphviz = false;
bool brief = false;
std::string output;
void parse(argument_parser& ap)
{
l.parse(ap);
ap(graphviz,
{"--graphviz", "-g"},
ap.help("Print out a graphviz representation."),
ap.set_value(true));
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(output, {"--output", "-o"}, ap.help("Output to file."));
}
void run() void run()
{ {
auto p = l.load(); auto p = l.load();
std::cout << p << std::endl;
auto* os = &std::cout;
std::ofstream fs;
if(not output.empty())
{
fs.open(output);
os = &fs;
}
if(graphviz)
p.print_graph(*os, brief);
else
*os << p << std::endl;
} }
}; };
......
...@@ -116,7 +116,7 @@ struct program ...@@ -116,7 +116,7 @@ struct program
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
void debug_print(const std::vector<instruction_ref>& inss) const; void debug_print(const std::vector<instruction_ref>& inss) const;
void print_graph(std::ostream& os) const; void print_graph(std::ostream& os, bool brief = false) const;
void dry_run(parameter_map params) const; void dry_run(parameter_map params) const;
......
...@@ -591,21 +591,26 @@ static std::string enclose_name(const std::string& name) ...@@ -591,21 +591,26 @@ static std::string enclose_name(const std::string& name)
return '"' + replace_string(name, "\"", "\\\"") + '"'; return '"' + replace_string(name, "\"", "\\\"") + '"';
} }
void program::print_graph(std::ostream& os) const void program::print_graph(std::ostream& os, bool brief) const
{ {
os << "digraph {" << std::endl; os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl; os << "\trankdir=LR;" << std::endl;
print_program(*this, [&](auto ins, const auto& names) { print_program(*this, [&](auto ins, const auto& names) {
os << "\t" << enclose_name(names.at(ins)) std::string label;
<< "[label=" << enclose_name(to_string(ins->get_operator())) << "];"; if(brief)
os << std::endl; label = ins->name();
else
label = to_string(ins->get_operator());
os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl;
if(!ins->inputs().empty()) if(!ins->inputs().empty())
{ {
for(auto&& arg : ins->inputs()) for(auto&& arg : ins->inputs())
{ {
os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins)); os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins));
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "];"; if(not brief)
os << std::endl; os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]";
os << ";" << std::endl;
} }
} }
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment