Commit 13122f11 authored by Shucai Xiao

Code refinement and additional changes to support the RNN operator.

parent 0427ce2b
@@ -31,6 +31,7 @@ struct onnx_parser
     bool is_pytorch = false;
     std::unordered_map<std::string, op_func> ops;
+    std::unordered_map<std::string, operation> actv_funcs;

     onnx_parser()
     {
@@ -85,6 +86,16 @@ struct onnx_parser
         add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
         add_mem_op("Transpose", &onnx_parser::parse_transpose);
         add_mem_op("RNN", &onnx_parser::parse_rnn);
+
+        // init the activation function map
+        init_actv_func();
+    }
+
+    void init_actv_func()
+    {
+        actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
+        actv_funcs.insert(std::make_pair("relu", op::relu{}));
+        actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
     }

     template <class F>
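The map built in init_actv_func is a small name-to-operation registry: parse_rnn can turn the string from the ONNX "activations" attribute into a MIGraphX operation with a single lookup, instead of rebuilding the table on every call. Below is a minimal standalone sketch of the same pattern; the lambda activations are hypothetical stand-ins, since op::tanh, op::relu, and op::sigmoid are MIGraphX internals.

    // Sketch of a name -> operation registry, with hypothetical
    // stand-ins for MIGraphX's op::tanh / op::relu / op::sigmoid.
    #include <cmath>
    #include <functional>
    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>

    using activation = std::function<double(double)>;

    std::unordered_map<std::string, activation> make_actv_funcs()
    {
        std::unordered_map<std::string, activation> m;
        m.insert(std::make_pair("tanh", [](double x) { return std::tanh(x); }));
        m.insert(std::make_pair("relu", [](double x) { return x > 0 ? x : 0.0; }));
        m.insert(std::make_pair("sigmoid", [](double x) { return 1.0 / (1.0 + std::exp(-x)); }));
        return m;
    }

    int main()
    {
        auto actv_funcs = make_actv_funcs();
        std::string name = "tanh"; // e.g. taken from the ONNX "activations" attribute
        if(actv_funcs.count(name) == 0)
            throw std::runtime_error("activation function " + name + " not supported");
        std::cout << actv_funcs[name](0.5) << '\n'; // prints tanh(0.5)
    }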
@@ -656,12 +667,7 @@ struct onnx_parser
             activation_func = attributes.at("activations").strings(0);
         }
-
-        std::unordered_map<std::string, operation> actv_func_map;
-        actv_func_map.insert(std::make_pair("tanh", op::tanh{}));
-        actv_func_map.insert(std::make_pair("relu", op::relu{}));
-        actv_func_map.insert(std::make_pair("sigmoid", op::sigmoid{}));
-        if(actv_func_map.count(activation_func) == 0)
+        if(actv_funcs.count(activation_func) == 0)
         {
             MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
         }
@@ -690,8 +696,8 @@ struct onnx_parser
             clip = parse_value(attributes.at("clip")).at<float>();
         }
-        return prog.add_instruction(
-            op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
+        return prog.add_instruction(op::rnn{hidden_size, actv_funcs[activation_func], dirct, clip},
+                                    std::move(args));
     }

     void parse_from(std::istream& is)
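The pieces wired into op::rnn here correspond directly to the ONNX RNN operator definition: hidden_size fixes the state width, the activation f is applied element-wise, and clip (when present) limits the pre-activation values before f is applied. A minimal single-step sketch of that cell update, with hand-rolled vectors standing in for MIGraphX tensors:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <vector>

    // One forward step of the ONNX RNN cell:
    //   H_t = f(clip(X_t * W^T + H_{t-1} * R^T + Wb + Rb))
    std::vector<double> rnn_step(const std::vector<double>& x,              // input, size input_size
                                 const std::vector<double>& h_prev,         // state, size hidden_size
                                 const std::vector<std::vector<double>>& W, // hidden_size x input_size
                                 const std::vector<std::vector<double>>& R, // hidden_size x hidden_size
                                 const std::vector<double>& b,              // Wb + Rb, size hidden_size
                                 const std::function<double(double)>& f,    // activation, e.g. tanh
                                 double clip)                               // 0 disables clipping here
    {
        std::vector<double> h(h_prev.size());
        for(std::size_t i = 0; i < h.size(); i++)
        {
            double acc = b[i];
            for(std::size_t j = 0; j < x.size(); j++)
                acc += W[i][j] * x[j];
            for(std::size_t j = 0; j < h_prev.size(); j++)
                acc += R[i][j] * h_prev[j];
            if(clip > 0)
                acc = std::max(-clip, std::min(clip, acc)); // ONNX clips before the activation
            h[i] = f(acc);
        }
        return h;
    }

    int main()
    {
        // hidden_size = 2, input_size = 1, tanh activation, clipping disabled
        auto h = rnn_step({1.0},
                          {0.0, 0.0},
                          {{0.5}, {-0.25}},
                          {{0.1, 0.0}, {0.0, 0.1}},
                          {0.0, 0.0},
                          [](double v) { return std::tanh(v); },
                          0.0);
        std::cout << h[0] << ' ' << h[1] << '\n'; // tanh(0.5) tanh(-0.25)
    }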
@@ -750,6 +756,16 @@ struct onnx_parser
         std::vector<instruction_ref> args;
         for(auto&& input : node.input())
         {
+            // For RNN, LSTM, and GRU operators, one of the input arguments
+            // is prim::Undefined, and it is ignored by protobuf. We use a
+            // hack to ignore this argument for these three operators
+            std::string op_type = node.op_type();
+            if((op_type == "RNN" || op_type == "LSTM" || op_type == "GRU") &&
+               input.empty() == true)
+            {
+                continue;
+            }
+
             if(nodes.count(input) > 0)
             {
                 auto&& iname = get_name(nodes.at(input));
...
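The empty-string check works because ONNX encodes an omitted optional input as an empty name in the node's input list; for RNN, LSTM, and GRU that is typically the optional B, sequence_lens, or initial_h slot. A standalone sketch of the same skip logic over a plain vector of hypothetical input names, without the protobuf types:

    #include <iostream>
    #include <string>
    #include <vector>

    // ONNX represents an omitted optional input as an empty name in the
    // node's input list; skip those the same way the parser does.
    int main()
    {
        std::string op_type = "RNN";
        // hypothetical input list: X, W, R present; B and sequence_lens omitted
        std::vector<std::string> inputs = {"X", "W", "R", "", ""};

        for(auto&& input : inputs)
        {
            if((op_type == "RNN" || op_type == "LSTM" || op_type == "GRU") && input.empty())
            {
                continue; // omitted optional input, nothing to look up
            }
            std::cout << "resolve argument: " << input << '\n';
        }
    }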