"example/vscode:/vscode.git/clone" did not exist on "3ab3cf390755851a04341e6622815b076521112e"
Commit 613772dd authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent d511dd80
......@@ -279,7 +279,7 @@ struct onnx_parser
}
int keep_dims = 1;
if (contains(attributes, "keepdims"))
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
......@@ -298,7 +298,7 @@ struct onnx_parser
}
int keep_dims = 1;
if (contains(attributes, "keepdims"))
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
......
......@@ -651,12 +651,13 @@ struct cpu_argmax
std::string name() const { return "cpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template<class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
template <class T>
int64_t
calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
{
auto max_val = input(indices.begin(), indices.end());
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for (std::size_t i = 1; i < item_num; ++i)
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
if(max_val < input(indices.begin(), indices.end()))
......@@ -680,8 +681,8 @@ struct cpu_argmax
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = batch_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num, op.axis);
auto data_idx = batch_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num, op.axis);
});
});
});
......@@ -703,12 +704,13 @@ struct cpu_argmin
std::string name() const { return "cpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template<class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
template <class T>
int64_t
calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
{
auto min_val = input(indices.begin(), indices.end());
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for (std::size_t i = 1; i < item_num; ++i)
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
if(min_val > input(indices.begin(), indices.end()))
......@@ -732,8 +734,8 @@ struct cpu_argmin
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = batch_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num, op.axis);
auto data_idx = batch_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num, op.axis);
});
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment