Commit af90a792 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'lstm_operator' into seq2seq_example

parents 89b80be6 7031da28
...@@ -787,17 +787,14 @@ struct onnx_parser ...@@ -787,17 +787,14 @@ struct onnx_parser
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform( std::copy(names.begin(), names.end(), vec_names.begin());
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) { auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0); return (map_actv_funcs.count(name) == 0);
})) });
if(name_it != vec_names.end())
{ {
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported"); MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
} }
...@@ -881,8 +878,7 @@ struct onnx_parser ...@@ -881,8 +878,7 @@ struct onnx_parser
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform( std::copy(names.begin(), names.end(), vec_names.begin());
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
// need 4 activation functions // need 4 activation functions
...@@ -920,13 +916,11 @@ struct onnx_parser ...@@ -920,13 +916,11 @@ struct onnx_parser
} }
} }
if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) { auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0); return (map_actv_funcs.count(name) == 0);
})) });
if(name_it != vec_names.end())
{ {
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported"); MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
} }
...@@ -1011,8 +1005,7 @@ struct onnx_parser ...@@ -1011,8 +1005,7 @@ struct onnx_parser
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform( std::copy(names.begin(), names.end(), vec_names.begin());
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
// need 6 activation functions for bidirectional directions // need 6 activation functions for bidirectional directions
...@@ -1093,13 +1086,11 @@ struct onnx_parser ...@@ -1093,13 +1086,11 @@ struct onnx_parser
} }
} }
if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) { auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0); return (map_actv_funcs.count(name) == 0);
})) });
if(name_it != vec_names.end())
{ {
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported"); MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment