Unverified Commit d32ab85b authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Improvements to driver output (#1710)

Use generate_argument instead of generate_literal for python output as generate_literal doesnt exists
Shorten the names for variables from the main module
Use prefix p_ for parameters
Use shorter variable m for main module in python
parent 55f420fb
......@@ -723,15 +723,15 @@ std::unordered_map<instruction_ref, std::string> module::print(
for(auto ins : iterator_for(*this))
{
std::string var_name;
if(not this->name().empty() and this->name() != "main")
var_name = this->name() + ":";
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
var_name.append(any_cast<builtin::param>(ins->get_operator()).parameter);
}
else
{
var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count));
var_name.append("@" + std::to_string(count));
}
// count every instruction so index matches loc in the printout program
count++;
......@@ -795,7 +795,10 @@ static std::string to_c_id(const std::string& name, char rep = '_')
static std::string cpp_var_name(const std::string& name)
{
return to_c_id("x_" + replace_string(name, ":", "_module_"));
std::string prefix = "x_";
if(not contains(name, "@"))
prefix = "p_";
return to_c_id(prefix + replace_string(name, ":", "_module_"));
}
static void print_py_op(std::ostream& os, const operation& op)
......@@ -875,7 +878,7 @@ module::print_py(std::ostream& os,
use_abs = false;
if(use_abs)
os << "migraphx.abs_literal(";
os << "migraphx.generate_literal(";
os << "migraphx.generate_argument(";
print_py_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
......
......@@ -861,7 +861,9 @@ void program::print_py(std::ostream& os) const
os << "p = migraphx.program()\n";
for(auto& mod : vec_modules)
{
std::string var_name = "m" + mod->name();
std::string var_name = "m";
if(mod->name() != "main")
var_name += mod->name();
os << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module()";
......
......@@ -54,15 +54,15 @@ TEST_CASE(basic_graph_test)
EXPECT(migraphx::contains(test, "digraph"));
EXPECT(migraphx::contains(test, "rankdir=LR"));
EXPECT(migraphx::contains(test, "\"main:@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]"));
EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]"));
EXPECT(migraphx::contains(test, "\"main:@3\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"main:@4\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"main:@3\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"main:@3\""));
EXPECT(migraphx::contains(test, "\"main:@3\" -> \"main:@4\""));
EXPECT(migraphx::contains(test, "\"main:@0\" -> \"main:@4\""));
EXPECT(migraphx::contains(test, "\"@3\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"@4\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"@3\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"@3\""));
EXPECT(migraphx::contains(test, "\"@3\" -> \"@4\""));
EXPECT(migraphx::contains(test, "\"@0\" -> \"@4\""));
EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]"));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment