Commit 36221655 authored by Khalique's avatar Khalique
Browse files

continued progres on testing print_graph

parent 39cc651d
...@@ -7,7 +7,6 @@ int main(int argc, char const* argv[]) ...@@ -7,7 +7,6 @@ int main(int argc, char const* argv[])
{ {
std::string file = argv[1]; std::string file = argv[1];
auto prog = migraphx::parse_onnx(file); auto prog = migraphx::parse_onnx(file);
// std::cout << prog << std::endl; std::cout << prog << std::endl;
prog.print_graph(std::cout);
} }
} }
...@@ -54,43 +54,11 @@ static void print_instruction(std::ostream& os, ...@@ -54,43 +54,11 @@ static void print_instruction(std::ostream& os,
os << " -> " << ins->get_shape(); os << " -> " << ins->get_shape();
} }
static std::string enclose_name(const std::string& name) { return '"' + name + '"'; }
static void print_graph_node(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << "\t";
if(!ins->inputs().empty())
{
char delim = '{';
for(auto&& arg : ins->inputs())
{
os << delim << enclose_name(names.at(arg));
delim = ' ';
}
os << '}';
os << " -> ";
}
os << enclose_name(names.at(ins)) << ";";
// if(ins->name() == "@literal")
// {
// if(ins->get_literal().get_shape().elements() > 10)
// os << "{ ... }";
// else
// os << "{" << ins->get_literal() << "}";
// }
}
template <class F> template <class F>
static void print_program( static void print_program(
std::ostream& os,
const program& p, const program& p,
F annonate, F print_func
std::function<void( )
std::ostream&, instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>
print_func = print_instruction)
{ {
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
int count = 0; int count = 0;
...@@ -111,11 +79,7 @@ static void print_program( ...@@ -111,11 +79,7 @@ static void print_program(
(void)arg; (void)arg;
} }
print_func(os, ins, names); print_func(ins, names);
annonate(ins, names);
os << std::endl;
count++; count++;
} }
...@@ -510,10 +474,12 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -510,10 +474,12 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
double calculate_overhead_time = total_time - total_instruction_time; double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
print_program(os, *this, [&](auto ins, auto&&) { print_program(*this, [&](auto ins, const auto& names) {
print_instruction(std::cout,ins, names);
double avg = common_average(ins_vec[ins]); double avg = common_average(ins_vec[ins]);
double percent = std::ceil(100.0 * avg / total_instruction_time); double percent = std::ceil(100.0 * avg / total_instruction_time);
os << ": " << avg << "ms, " << percent << "%"; os << ": " << avg << "ms, " << percent << "%";
os << std::endl;
}); });
os << std::endl; os << std::endl;
...@@ -551,7 +517,7 @@ void program::debug_print(instruction_ref ins) const ...@@ -551,7 +517,7 @@ void program::debug_print(instruction_ref ins) const
return; return;
} }
std::stringstream ss; std::stringstream ss;
print_program(ss, *this, [&](auto x, auto&& names) { print_program(*this, [&](auto x, const auto& names) {
if(x == ins) if(x == ins)
{ {
print_instruction(std::cout, x, names); print_instruction(std::cout, x, names);
...@@ -566,11 +532,29 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const ...@@ -566,11 +532,29 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
std::cout << std::endl; std::cout << std::endl;
} }
static std::string enclose_name(const std::string& name)
{
std::string new_name = name;
return '"' + replace_string(new_name, "\"", "\\\"") + '"';
}
void program::print_graph(std::ostream& os) const void program::print_graph(std::ostream& os) const
{ {
os << "digraph {" << std::endl; os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl; os << "\trankdir=LR;" << std::endl;
print_program(os, *this, [](auto&&...) {}, print_graph_node); print_program(*this, [&](auto ins, const auto& names) {
os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(to_string(ins->get_operator())) << "];";
os << std::endl;
if(!ins->inputs().empty())
{
for(auto&& arg : ins->inputs())
{
os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins));
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "];";
os << std::endl;
}
}
});
os << "}" << std::endl; os << "}" << std::endl;
} }
...@@ -582,14 +566,21 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const ...@@ -582,14 +566,21 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{ {
print_program(os, *this, [&](auto ins, auto&&) { a(ins); }); print_program(*this, [&](auto ins, const auto& names) {
print_instruction(os, ins, names);
a(ins);
os << std::endl;
});
} }
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); } bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p) std::ostream& operator<<(std::ostream& os, const program& p)
{ {
print_program(os, p, [](auto&&...) {}); print_program(p, [&](auto ins, const auto& names) {
print_instruction(os, ins, names);
os << std::endl;
});
return os; return os;
} }
......
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
migraphx::program create_program()
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type});
auto sum = p.add_instruction(sum_op{}, x, y);
auto one = p.add_literal(1);
p.add_instruction(sum_op{}, sum, one);
return p;
}
TEST_CASE(basic_graph_test)
{
migraphx::program p = create_program();
std::stringstream ss;
p.print_graph(ss);
std::string test = ss.str();
std::cout << test << std::endl;
EXPECT(test.find("[label=@literal]") != std::string::npos);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment