Commit 14480a27 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the foreach algorithm to std::any_of"

parent 54bfa2d8
...@@ -750,15 +750,20 @@ struct onnx_parser ...@@ -750,15 +750,20 @@ struct onnx_parser
{ {
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); }); vec_names.resize(names.size());
std::transform(
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) { if (std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(fn) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&] (auto &name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
}
// bidirectional case should have two activation functions. // bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse. // one is for forward, and the other is for reverse.
...@@ -879,12 +884,15 @@ struct onnx_parser ...@@ -879,12 +884,15 @@ struct onnx_parser
} }
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) { if (std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&] (auto &name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size()); std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) { std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
...@@ -1049,12 +1057,15 @@ struct onnx_parser ...@@ -1049,12 +1057,15 @@ struct onnx_parser
} }
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) { if (std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("LSTM: activation function " + std::string(name) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&] (auto &name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size()); std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) { std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment