Unverified Commit 56c43445 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Print program as python (#1490)

* Print python code
parent b8c8d09b
......@@ -109,8 +109,12 @@ struct loader
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(output_type,
{"--cpp"},
ap.help("Print out the program as cpp program."),
ap.help("Print out the program as C++ program."),
ap.set_value("cpp"));
ap(output_type,
{"--python", "--py"},
ap.help("Print out the program as python program."),
ap.set_value("py"));
ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
ap(output_type,
{"--text"},
......@@ -259,7 +263,9 @@ struct loader
type = "binary";
}
if(type == "cpp")
if(type == "py")
p.print_py(*os);
else if(type == "cpp")
p.print_cpp(*os);
else if(type == "graphviz")
p.print_graph(*os, brief);
......
......@@ -205,6 +205,12 @@ struct module
void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const;
void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os,
......
......@@ -115,6 +115,7 @@ struct program
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
void print_cpp(std::ostream& os) const;
void dry_run(parameter_map params) const;
......
......@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return to_c_id("x_" + replace_string(name, ":", "_module_"));
}
static void print_py_op(std::ostream& os, const operation& op)
{
auto v = op.to_value();
os << "migraphx.op(" << enclose_name(op.name());
auto default_values = make_op(op.name()).to_value();
for(auto&& x : v)
{
auto name = x.get_key();
if(default_values[name] == x)
continue;
os << ", " << name << "=" << to_json_string(x.without_key());
}
os << ")";
}
static void print_make_op(std::ostream& os, const operation& op)
{
auto v = op.to_value();
......@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
os << ")";
}
static void print_py_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
if(not s.standard())
os << ", strides=" << to_json_string(s.strides());
os << ")";
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx::shape{migraphx::shape::" << s.type_string();
......@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os << "}";
}
std::unordered_map<instruction_ref, std::string>
module::print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{
// cppcheck-suppress variableScope
unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print(
[&](auto ins, auto ins_names) {
std::vector<std::string> input_vars;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(input_vars),
[&](auto input) { return cpp_var_name(ins_names.at(input)); });
if(ins != last)
os << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << mname << ".add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
// Disable abs for now
use_abs = false;
if(use_abs)
os << "migraphx.abs_literal(";
os << "migraphx.generate_literal(";
print_py_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ")" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << mname << ".add_parameter(" << enclose_name(name) << ",";
print_py_shape(os, ins->get_shape());
os << ")" << std::endl;
}
else if(ins->name() == "@return")
{
os << mname << ".add_return([" << join_strings(input_vars, ", ") << "])"
<< std::endl;
}
else
{
assert(ins->name().front() != '@');
os << mname << ".add_instruction(";
print_py_op(os, ins->get_operator());
os << ", [" << join_strings(input_vars, ", ") << "]";
os << ")" << std::endl;
}
},
names);
return names;
}
std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os,
const std::string& mname,
......@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
return names;
}
void module::print_py(std::ostream& os) const { this->print_py(os, this->name(), {}); }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
......
......@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
mm->print_graph(os, brief);
}
void program::print_py(std::ostream& os) const
{
auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names;
os << "p = migraphx.program()\n";
for(auto& mod : vec_modules)
{
std::string var_name = "m" + mod->name();
os << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module()";
else
os << "p.create_module(\"" << mod->name() << "\");";
os << std::endl;
names = mod->print_py(os, var_name, names);
os << std::endl;
}
}
void program::print_cpp(std::ostream& os) const
{
auto vec_modules = this->get_modules();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment