Unverified Commit 3f3885ac authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Parse graph node topologically (#479)



* python api support scalar argument

* clang format

* change parse operator function signature

* clang format

* add parsing the split operator

* clang format

* add parsing split operator

* make squeeze/unsqueeze inputs to standard shape

* add unit tests for the split operator

* clang format

* fix cppcheck error

* clang format

* update tests for multiple program outputs

* clang format

* update the function parse_graph

* clang format

* fixed an unit test

* revert code back

* remove blank line

* refine an error message

* add unit tests for code change

* clang format

* refine an error message
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 1a692f60
......@@ -1834,7 +1834,6 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph)
{
nodes = get_nodes(graph);
for(auto&& f : graph.initializer())
instructions[f.name()] = prog.add_literal(parse_tensor(f));
......@@ -1849,9 +1848,41 @@ struct onnx_parser
instructions[name] = prog.add_parameter(name, s);
}
}
for(auto&& output : graph.output())
for(auto&& node : graph.node())
{
this->parse_node(output.name());
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
if(input.empty())
{
this->parse_undefined(input);
}
if(instructions.count(input) == 0)
{
MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
"\" is unavailable due to unordered nodes!");
}
args.push_back(instructions.at(input));
}
std::vector<instruction_ref> result;
std::size_t output_num = static_cast<std::size_t>(node.output().size());
if(ops.count(node.op_type()) == 0)
{
result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
}
else
{
result = ops[node.op_type()]({get_attributes(node), output_num}, args);
}
output_num = std::min<std::size_t>(output_num, result.size());
std::transform(node.output().begin(),
node.output().begin() + output_num,
result.begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(x, y); });
}
// Find instructions corresponding to the output
......@@ -1884,54 +1915,6 @@ struct onnx_parser
instructions[name] = ins;
}
void parse_node(const std::string& name)
{
if(name.empty())
MIGRAPHX_THROW("Onnx node must have a name");
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
if(input.empty())
{
this->parse_undefined(input);
}
else if(nodes.count(input) > 0)
{
assert(name != input);
this->parse_node(input);
}
args.push_back(instructions.at(input));
}
std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0)
{
result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
}
else
{
std::size_t output_num = static_cast<std::size_t>(node.output().size());
result = ops[node.op_type()]({get_attributes(node), output_num}, args);
}
// Even no output nodes produce output in migraphx
if(node.output().empty() and result.size() == 1)
{
instructions[name] = result.front();
}
else
{
auto output_num = std::min<std::size_t>(node.output().size(), result.size());
std::transform(node.output().begin(),
node.output().begin() + output_num,
result.begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(x, y); });
}
}
}
static attribute_map get_attributes(const onnx::NodeProto& node)
{
std::unordered_map<std::string, onnx::AttributeProto> result;
......@@ -1942,32 +1925,6 @@ struct onnx_parser
return result;
}
static node_map get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
std::size_t n = 0;
for(auto&& node : graph.node())
{
if(node.output().empty())
{
if(node.name().empty())
{
result["migraphx_unamed_node_" + std::to_string(n)] = node;
n++;
}
else
{
result[node.name()] = node;
}
}
for(auto&& output : node.output())
{
result[output] = node;
}
}
return result;
}
template <class T>
static literal from_repeated(shape::type_t t, const T& r)
{
......
......@@ -2109,6 +2109,16 @@ def transpose_gather_test():
return ([td, ti, node], [x, i], [y])
@onnx_test
def undefined_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 3, 4, 5])
node = onnx.helper.make_node('Identity', inputs=[''], outputs=['1'])
return ([node], [x], [y])
@onnx_test
def unknown_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
......@@ -1345,10 +1345,10 @@ TEST_CASE(shape_gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {7, 3, 10}});
migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = optimize_onnx("shape_gather_test.onnx");
......@@ -1391,10 +1391,10 @@ TEST_CASE(slice_5arg_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-5, -3}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-1, -1}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {1, 1}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-1, -1}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-5, -3}});
auto ret = p.add_instruction(migraphx::op::slice{{-1, -2}, {-5, -3}, {-1, -1}}, l0);
p.add_return({ret});
......@@ -1594,6 +1594,19 @@ TEST_CASE(transpose_gather_test)
EXPECT(p == prog);
}
TEST_CASE(undefined_test)
{
migraphx::program p;
p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_instruction(migraphx::op::undefined{});
auto l2 = p.add_instruction(migraphx::op::identity{}, l1);
p.add_return({l2});
auto prog = migraphx::parse_onnx("undefined_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment