Commit 23124b09 authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Print graph from the driver (#389)

* Print graph from the driver

* Formatting
parent 50e6d5eb
......@@ -7,16 +7,19 @@
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <fstream>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -37,7 +40,7 @@ struct loader
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
}
program load()
......@@ -63,6 +66,7 @@ struct loader
if(optimize)
migraphx::run_passes(p,
{
migraphx::rewrite_batchnorm{},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{},
migraphx::simplify_algebra{},
......@@ -126,12 +130,36 @@ struct compiler
struct read : command<read>
{
loader l;
void parse(argument_parser& ap) { l.parse(ap); }
bool graphviz = false;
bool brief = false;
std::string output;
void parse(argument_parser& ap)
{
l.parse(ap);
ap(graphviz,
{"--graphviz", "-g"},
ap.help("Print out a graphviz representation."),
ap.set_value(true));
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(output, {"--output", "-o"}, ap.help("Output to file."));
}
void run()
{
auto p = l.load();
std::cout << p << std::endl;
auto* os = &std::cout;
std::ofstream fs;
if(not output.empty())
{
fs.open(output);
os = &fs;
}
if(graphviz)
p.print_graph(*os, brief);
else
*os << p << std::endl;
}
};
......
......@@ -116,7 +116,7 @@ struct program
void debug_print() const;
void debug_print(instruction_ref ins) const;
void debug_print(const std::vector<instruction_ref>& inss) const;
void print_graph(std::ostream& os) const;
void print_graph(std::ostream& os, bool brief = false) const;
void dry_run(parameter_map params) const;
......
......@@ -591,21 +591,26 @@ static std::string enclose_name(const std::string& name)
return '"' + replace_string(name, "\"", "\\\"") + '"';
}
void program::print_graph(std::ostream& os) const
void program::print_graph(std::ostream& os, bool brief) const
{
os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl;
print_program(*this, [&](auto ins, const auto& names) {
os << "\t" << enclose_name(names.at(ins))
<< "[label=" << enclose_name(to_string(ins->get_operator())) << "];";
os << std::endl;
std::string label;
if(brief)
label = ins->name();
else
label = to_string(ins->get_operator());
os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl;
if(!ins->inputs().empty())
{
for(auto&& arg : ins->inputs())
{
os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins));
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "];";
os << std::endl;
if(not brief)
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]";
os << ";" << std::endl;
}
}
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment