Unverified Commit d32ab85b authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Improvements to driver output (#1710)

Use generate_argument instead of generate_literal for python output as generate_literal doesnt exists
Shorten the names for variables from the main module
Use prefix p_ for parameters
Use shorter variable m for main module in python
parent 55f420fb
...@@ -723,15 +723,15 @@ std::unordered_map<instruction_ref, std::string> module::print( ...@@ -723,15 +723,15 @@ std::unordered_map<instruction_ref, std::string> module::print(
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
std::string var_name; std::string var_name;
if(not this->name().empty() and this->name() != "main")
var_name = this->name() + ":";
if(ins->name() == "@param") if(ins->name() == "@param")
{ {
var_name = any_cast<builtin::param>(ins->get_operator()).parameter; var_name.append(any_cast<builtin::param>(ins->get_operator()).parameter);
} }
else else
{ {
var_name = this->name(); var_name.append("@" + std::to_string(count));
var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count));
} }
// count every instruction so index matches loc in the printout program // count every instruction so index matches loc in the printout program
count++; count++;
...@@ -795,7 +795,10 @@ static std::string to_c_id(const std::string& name, char rep = '_') ...@@ -795,7 +795,10 @@ static std::string to_c_id(const std::string& name, char rep = '_')
static std::string cpp_var_name(const std::string& name) static std::string cpp_var_name(const std::string& name)
{ {
return to_c_id("x_" + replace_string(name, ":", "_module_")); std::string prefix = "x_";
if(not contains(name, "@"))
prefix = "p_";
return to_c_id(prefix + replace_string(name, ":", "_module_"));
} }
static void print_py_op(std::ostream& os, const operation& op) static void print_py_op(std::ostream& os, const operation& op)
...@@ -875,7 +878,7 @@ module::print_py(std::ostream& os, ...@@ -875,7 +878,7 @@ module::print_py(std::ostream& os,
use_abs = false; use_abs = false;
if(use_abs) if(use_abs)
os << "migraphx.abs_literal("; os << "migraphx.abs_literal(";
os << "migraphx.generate_literal("; os << "migraphx.generate_argument(";
print_py_shape(os, ins->get_shape()); print_py_shape(os, ins->get_shape());
os << ", " << seed << ")"; os << ", " << seed << ")";
if(use_abs) if(use_abs)
......
...@@ -861,7 +861,9 @@ void program::print_py(std::ostream& os) const ...@@ -861,7 +861,9 @@ void program::print_py(std::ostream& os) const
os << "p = migraphx.program()\n"; os << "p = migraphx.program()\n";
for(auto& mod : vec_modules) for(auto& mod : vec_modules)
{ {
std::string var_name = "m" + mod->name(); std::string var_name = "m";
if(mod->name() != "main")
var_name += mod->name();
os << var_name << " = "; os << var_name << " = ";
if(mod->name() == "main") if(mod->name() == "main")
os << "p.get_main_module()"; os << "p.get_main_module()";
......
...@@ -54,15 +54,15 @@ TEST_CASE(basic_graph_test) ...@@ -54,15 +54,15 @@ TEST_CASE(basic_graph_test)
EXPECT(migraphx::contains(test, "digraph")); EXPECT(migraphx::contains(test, "digraph"));
EXPECT(migraphx::contains(test, "rankdir=LR")); EXPECT(migraphx::contains(test, "rankdir=LR"));
EXPECT(migraphx::contains(test, "\"main:@0\"[label=\"@literal\"]")); EXPECT(migraphx::contains(test, "\"@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]")); EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]"));
EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]")); EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]"));
EXPECT(migraphx::contains(test, "\"main:@3\"[label=\"sum\"]")); EXPECT(migraphx::contains(test, "\"@3\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"main:@4\"[label=\"sum\"]")); EXPECT(migraphx::contains(test, "\"@4\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"main:@3\"")); EXPECT(migraphx::contains(test, "\"x\" -> \"@3\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"main:@3\"")); EXPECT(migraphx::contains(test, "\"y\" -> \"@3\""));
EXPECT(migraphx::contains(test, "\"main:@3\" -> \"main:@4\"")); EXPECT(migraphx::contains(test, "\"@3\" -> \"@4\""));
EXPECT(migraphx::contains(test, "\"main:@0\" -> \"main:@4\"")); EXPECT(migraphx::contains(test, "\"@0\" -> \"@4\""));
EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]")); EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]"));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment