Commit e986fc57 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_pad

parents 4fc22f8f 56c43445
...@@ -109,8 +109,12 @@ struct loader ...@@ -109,8 +109,12 @@ struct loader
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true)); ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(output_type, ap(output_type,
{"--cpp"}, {"--cpp"},
ap.help("Print out the program as cpp program."), ap.help("Print out the program as C++ program."),
ap.set_value("cpp")); ap.set_value("cpp"));
ap(output_type,
{"--python", "--py"},
ap.help("Print out the program as python program."),
ap.set_value("py"));
ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json")); ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
ap(output_type, ap(output_type,
{"--text"}, {"--text"},
...@@ -259,7 +263,9 @@ struct loader ...@@ -259,7 +263,9 @@ struct loader
type = "binary"; type = "binary";
} }
if(type == "cpp") if(type == "py")
p.print_py(*os);
else if(type == "cpp")
p.print_cpp(*os); p.print_cpp(*os);
else if(type == "graphviz") else if(type == "graphviz")
p.print_graph(*os, brief); p.print_graph(*os, brief);
......
...@@ -30,23 +30,31 @@ namespace migraphx { ...@@ -30,23 +30,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
T generic_read_file(const std::string& filename) T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{ {
std::ifstream is(filename, std::ios::binary | std::ios::ate); std::ifstream is(filename, std::ios::binary | std::ios::ate);
std::streamsize size = is.tellg(); if(nbytes == 0)
if(size < 1) {
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes = is.tellg();
if(offset > nbytes)
MIGRAPHX_THROW("offset is larger than file size");
nbytes -= offset;
}
if(nbytes < 1)
MIGRAPHX_THROW("Invalid size for: " + filename); MIGRAPHX_THROW("Invalid size for: " + filename);
is.seekg(0, std::ios::beg); is.seekg(offset, std::ios::beg);
T buffer(size, 0); T buffer(nbytes, 0);
if(not is.read(&buffer[0], size)) if(not is.read(&buffer[0], nbytes))
MIGRAPHX_THROW("Error reading file: " + filename); MIGRAPHX_THROW("Error reading file: " + filename);
return buffer; return buffer;
} }
std::vector<char> read_buffer(const std::string& filename) std::vector<char> read_buffer(const std::string& filename, size_t offset, size_t nbytes)
{ {
return generic_read_file<std::vector<char>>(filename); return generic_read_file<std::vector<char>>(filename, offset, nbytes);
} }
std::string read_string(const std::string& filename) std::string read_string(const std::string& filename)
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> read_buffer(const std::string& filename); std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0);
std::string read_string(const std::string& filename); std::string read_string(const std::string& filename);
void write_buffer(const std::string& filename, const char* buffer, std::size_t size); void write_buffer(const std::string& filename, const char* buffer, std::size_t size);
......
...@@ -205,6 +205,12 @@ struct module ...@@ -205,6 +205,12 @@ struct module
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string> std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os, print_cpp(std::ostream& os,
......
...@@ -115,6 +115,7 @@ struct program ...@@ -115,6 +115,7 @@ struct program
print_func) const; print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
void dry_run(parameter_map params) const; void dry_run(parameter_map params) const;
......
...@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name) ...@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return to_c_id("x_" + replace_string(name, ":", "_module_")); return to_c_id("x_" + replace_string(name, ":", "_module_"));
} }
static void print_py_op(std::ostream& os, const operation& op)
{
auto v = op.to_value();
os << "migraphx.op(" << enclose_name(op.name());
auto default_values = make_op(op.name()).to_value();
for(auto&& x : v)
{
auto name = x.get_key();
if(default_values[name] == x)
continue;
os << ", " << name << "=" << to_json_string(x.without_key());
}
os << ")";
}
static void print_make_op(std::ostream& os, const operation& op) static void print_make_op(std::ostream& os, const operation& op)
{ {
auto v = op.to_value(); auto v = op.to_value();
...@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op) ...@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
os << ")"; os << ")";
} }
static void print_py_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
if(not s.standard())
os << ", strides=" << to_json_string(s.strides());
os << ")";
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
{ {
os << "migraphx::shape{migraphx::shape::" << s.type_string(); os << "migraphx::shape{migraphx::shape::" << s.type_string();
...@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) ...@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os << "}"; os << "}";
} }
std::unordered_map<instruction_ref, std::string>
module::print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{
// cppcheck-suppress variableScope
unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print(
[&](auto ins, auto ins_names) {
std::vector<std::string> input_vars;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(input_vars),
[&](auto input) { return cpp_var_name(ins_names.at(input)); });
if(ins != last)
os << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << mname << ".add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
// Disable abs for now
use_abs = false;
if(use_abs)
os << "migraphx.abs_literal(";
os << "migraphx.generate_literal(";
print_py_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ")" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << mname << ".add_parameter(" << enclose_name(name) << ",";
print_py_shape(os, ins->get_shape());
os << ")" << std::endl;
}
else if(ins->name() == "@return")
{
os << mname << ".add_return([" << join_strings(input_vars, ", ") << "])"
<< std::endl;
}
else
{
assert(ins->name().front() != '@');
os << mname << ".add_instruction(";
print_py_op(os, ins->get_operator());
os << ", [" << join_strings(input_vars, ", ") << "]";
os << ")" << std::endl;
}
},
names);
return names;
}
std::unordered_map<instruction_ref, std::string> std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, module::print_cpp(std::ostream& os,
const std::string& mname, const std::string& mname,
...@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os, ...@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
return names; return names;
} }
void module::print_py(std::ostream& os) const { this->print_py(os, this->name(), {}); }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); } void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
......
...@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const ...@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
{ {
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end()); std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
if(not t.external_data().empty()) auto type = get_type(t.data_type());
shape tensor_shape(type, dims);
auto external_data = t.external_data();
if(not external_data.empty())
{ {
const std::string& data_file = t.external_data().at(0).value(); const std::string& data_file = external_data.at(0).value();
auto raw_buffer = read_buffer(path + "/" + data_file); size_t num_data_fields = external_data.size();
size_t offset = 0;
size_t nbytes = tensor_shape.bytes();
if(num_data_fields > 1) // if offset field is present
{
offset = std::stoul(t.external_data().at(1).value());
}
if(num_data_fields > 2) // if nbytes field is present
{
nbytes = std::stoul(t.external_data().at(2).value());
}
auto raw_buffer = read_buffer(path + "/" + data_file, offset, nbytes);
std::string s(raw_buffer.begin(), raw_buffer.end()); std::string s(raw_buffer.begin(), raw_buffer.end());
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data()); return create_literal(type, dims, s.data());
} }
if(t.has_raw_data()) if(t.has_raw_data())
{ {
const std::string& s = t.raw_data(); const std::string& s = t.raw_data();
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data()); return create_literal(type, dims, s.data());
} }
......
...@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const ...@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
mm->print_graph(os, brief); mm->print_graph(os, brief);
} }
void program::print_py(std::ostream& os) const
{
auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names;
os << "p = migraphx.program()\n";
for(auto& mod : vec_modules)
{
std::string var_name = "m" + mod->name();
os << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module()";
else
os << "p.create_module(\"" << mod->name() << "\");";
os << std::endl;
names = mod->print_py(os, var_name, names);
os << std::endl;
}
}
void program::print_cpp(std::ostream& os) const void program::print_cpp(std::ostream& os) const
{ {
auto vec_modules = this->get_modules(); auto vec_modules = this->get_modules();
......
external_constant_test:¡
v0"Constant*g
value*[B const_tensorj)
locationexternal_constant_test.weightj
offset48j
length24p external_constant_testb
0

B
\ No newline at end of file
This diff is collapsed.
...@@ -1818,6 +1818,16 @@ migraphx::program create_external_data_prog() ...@@ -1818,6 +1818,16 @@ migraphx::program create_external_data_prog()
return p; return p;
} }
TEST_CASE(external_constant_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{{migraphx::shape::int64_type, {3}}, {0, 1, 2}});
auto prog = optimize_onnx("external_constant_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(external_data_test) TEST_CASE(external_data_test)
{ {
migraphx::program p = create_external_data_prog(); migraphx::program p = create_external_data_prog();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment