Commit 13122f11 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Code refinement and another change to support the RNN operator.

parent 0427ce2b
......@@ -31,6 +31,7 @@ struct onnx_parser
bool is_pytorch = false;
std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> actv_funcs;
onnx_parser()
{
......@@ -85,6 +86,16 @@ struct onnx_parser
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
// init the activation function map
init_actv_func();
}
void init_actv_func()
{
actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
actv_funcs.insert(std::make_pair("relu", op::relu{}));
actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
}
template <class F>
......@@ -656,12 +667,7 @@ struct onnx_parser
activation_func = attributes.at("activations").strings(0);
}
std::unordered_map<std::string, operation> actv_func_map;
actv_func_map.insert(std::make_pair("tanh", op::tanh{}));
actv_func_map.insert(std::make_pair("relu", op::relu{}));
actv_func_map.insert(std::make_pair("sigmoid", op::sigmoid{}));
if(actv_func_map.count(activation_func) == 0)
if(actv_funcs.count(activation_func) == 0)
{
MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
}
......@@ -690,8 +696,8 @@ struct onnx_parser
clip = parse_value(attributes.at("clip")).at<float>();
}
return prog.add_instruction(
op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
return prog.add_instruction(op::rnn{hidden_size, actv_funcs[activation_func], dirct, clip},
std::move(args));
}
void parse_from(std::istream& is)
......@@ -750,6 +756,16 @@ struct onnx_parser
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
// For RNN, LSTM, and GRU operators, one of the input arguments
// is prim::Undefined, and it is ignored by protobuf. We use a
// hack to ignore this argument for these three operators
std::string op_type = node.op_type();
if((op_type == "RNN" || op_type == "LSTM" || op_type == "GRU") &&
input.empty() == true)
{
continue;
}
if(nodes.count(input) > 0)
{
auto&& iname = get_name(nodes.at(input));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment