Commit 613772dd authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent d511dd80
...@@ -279,7 +279,7 @@ struct onnx_parser ...@@ -279,7 +279,7 @@ struct onnx_parser
} }
int keep_dims = 1; int keep_dims = 1;
if (contains(attributes, "keepdims")) if(contains(attributes, "keepdims"))
{ {
keep_dims = parse_value(attributes.at("keepdims")).at<int>(); keep_dims = parse_value(attributes.at("keepdims")).at<int>();
} }
...@@ -298,7 +298,7 @@ struct onnx_parser ...@@ -298,7 +298,7 @@ struct onnx_parser
} }
int keep_dims = 1; int keep_dims = 1;
if (contains(attributes, "keepdims")) if(contains(attributes, "keepdims"))
{ {
keep_dims = parse_value(attributes.at("keepdims")).at<int>(); keep_dims = parse_value(attributes.at("keepdims")).at<int>();
} }
......
...@@ -651,12 +651,13 @@ struct cpu_argmax ...@@ -651,12 +651,13 @@ struct cpu_argmax
std::string name() const { return "cpu::argmax"; } std::string name() const { return "cpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template<class T> template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const int64_t
calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
{ {
auto max_val = input(indices.begin(), indices.end()); auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0; int64_t max_index = 0;
for (std::size_t i = 1; i < item_num; ++i) for(std::size_t i = 1; i < item_num; ++i)
{ {
indices[axis] = i; indices[axis] = i;
if(max_val < input(indices.begin(), indices.end())) if(max_val < input(indices.begin(), indices.end()))
...@@ -680,8 +681,8 @@ struct cpu_argmax ...@@ -680,8 +681,8 @@ struct cpu_argmax
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) { par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = batch_shape.multi(i); auto data_idx = batch_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num, op.axis); output[i] = this->calc_argmax(input, data_idx, batch_item_num, op.axis);
}); });
}); });
}); });
...@@ -703,12 +704,13 @@ struct cpu_argmin ...@@ -703,12 +704,13 @@ struct cpu_argmin
std::string name() const { return "cpu::argmin"; } std::string name() const { return "cpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template<class T> template <class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const int64_t
calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
{ {
auto min_val = input(indices.begin(), indices.end()); auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0; int64_t min_index = 0;
for (std::size_t i = 1; i < item_num; ++i) for(std::size_t i = 1; i < item_num; ++i)
{ {
indices[axis] = i; indices[axis] = i;
if(min_val > input(indices.begin(), indices.end())) if(min_val > input(indices.begin(), indices.end()))
...@@ -732,8 +734,8 @@ struct cpu_argmin ...@@ -732,8 +734,8 @@ struct cpu_argmin
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) { par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = batch_shape.multi(i); auto data_idx = batch_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num, op.axis); output[i] = this->calc_argmin(input, data_idx, batch_item_num, op.axis);
}); });
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment