Commit 1ff8fc24 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code cleanup in the rewrite_rnn.cpp file

parent 1ea5faef
...@@ -751,17 +751,14 @@ struct onnx_parser ...@@ -751,17 +751,14 @@ struct onnx_parser
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform( std::copy(names.begin(), names.end(), vec_names.begin());
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) { auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0); return (map_actv_funcs.count(name) == 0);
})) });
if (name_it != vec_names.end())
{ {
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported"); MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
} }
...@@ -845,8 +842,7 @@ struct onnx_parser ...@@ -845,8 +842,7 @@ struct onnx_parser
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform( std::copy(names.begin(), names.end(), vec_names.begin());
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
// need 4 activation functions // need 4 activation functions
...@@ -884,13 +880,11 @@ struct onnx_parser ...@@ -884,13 +880,11 @@ struct onnx_parser
} }
} }
if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) { auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0); return (map_actv_funcs.count(name) == 0);
})) });
{ if (name_it != vec_names.end())
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { {
return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported"); MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
} }
...@@ -975,8 +969,7 @@ struct onnx_parser ...@@ -975,8 +969,7 @@ struct onnx_parser
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform( std::copy(names.begin(), names.end(), vec_names.begin());
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
// need 6 activation functions for bidirectional directions // need 6 activation functions for bidirectional directions
...@@ -1057,13 +1050,11 @@ struct onnx_parser ...@@ -1057,13 +1050,11 @@ struct onnx_parser
} }
} }
if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) { auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0); return (map_actv_funcs.count(name) == 0);
})) });
if (name_it != vec_names.end())
{ {
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported"); MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment