"docs/EN/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "d996a81cc5a01d630624557479725aeb29a6ab8b"
Commit c6078c1e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into gather_operator

parents e344f80d 2d80965f
......@@ -24,7 +24,8 @@ struct onnx_parser
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
......@@ -88,6 +89,15 @@ struct onnx_parser
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
});
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{
ops.emplace(name, f);
}
......@@ -95,7 +105,7 @@ struct onnx_parser
template <class F>
void add_mem_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
......@@ -103,7 +113,7 @@ struct onnx_parser
template <class T>
void add_binary_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast") and contains(attributes, "axis"))
......@@ -172,7 +182,7 @@ struct onnx_parser
template <class T>
void add_generic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
......@@ -180,7 +190,7 @@ struct onnx_parser
template <class T>
void add_variadic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()),
args.end(),
args.front(),
......@@ -643,7 +653,7 @@ struct onnx_parser
}
else
{
throw std::runtime_error("Failed reading");
MIGRAPHX_THROW("Failed reading onnx file.");
}
}
......@@ -673,7 +683,7 @@ struct onnx_parser
}
for(auto&& p : nodes)
{
this->parse_node(get_name(p.second));
this->parse_node(p.first);
}
}
......@@ -689,23 +699,37 @@ struct onnx_parser
{
if(nodes.count(input) > 0)
{
auto&& iname = get_name(nodes.at(input));
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(iname));
assert(name != input);
this->parse_node(input);
args.push_back(instructions.at(input));
}
else
{
args.push_back(instructions.at(input));
}
}
std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0)
{
instructions[name] = prog.add_instruction(unknown{node.op_type()}, args);
result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
}
else
{
result = ops[node.op_type()](get_attributes(node), args);
}
// Even no output nodes produce output in migraphx
if(node.output().empty() and result.size() == 1)
{
instructions[name] = result.front();
}
else
{
instructions[name] = ops[node.op_type()](get_attributes(node), args);
assert(node.output().size() >= result.size());
std::transform(result.begin(),
result.end(),
node.output().begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(y, x); });
}
}
}
......@@ -720,25 +744,24 @@ struct onnx_parser
return result;
}
static std::string get_name(const onnx::NodeProto& node)
{
if(node.name().empty())
{
std::string generated = "migraphx_unnamed_node";
return std::accumulate(node.output().begin(),
node.output().end(),
generated,
[](auto x, auto y) { return x + "_" + y; });
}
return node.name();
}
static node_map get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
std::size_t n = 0;
for(auto&& node : graph.node())
{
result[get_name(node)] = node;
if(node.output().empty())
{
if(node.name().empty())
{
result["migraphx_unamed_node_" + std::to_string(n)] = node;
n++;
}
else
{
result[node.name()] = node;
}
}
for(auto&& output : node.output())
{
result[output] = node;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment