Commit d511dd80 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix cppcheck errors.

parent eeda0606
......@@ -40,7 +40,7 @@ struct argmax
}
lens[axis] = 1;
if(!keep_dims)
if(keep_dims == 0)
{
lens.erase(lens.begin() + axis);
}
......
......@@ -40,7 +40,7 @@ struct argmin
}
lens[axis] = 1;
if(!keep_dims)
if(keep_dims == 0)
{
lens.erase(lens.begin() + axis);
}
......
......@@ -278,7 +278,13 @@ struct onnx_parser
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::argmax{axis}, std::move(args));
int keep_dims = 1;
if (contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
return prog.add_instruction(op::argmax{axis, keep_dims}, std::move(args));
}
instruction_ref parse_argmin(const std::string&,
......@@ -291,7 +297,13 @@ struct onnx_parser
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::argmin{axis}, std::move(args));
int keep_dims = 1;
if (contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
return prog.add_instruction(op::argmin{axis, keep_dims}, std::move(args));
}
instruction_ref
......
......@@ -651,6 +651,24 @@ struct cpu_argmax
std::string name() const { return "cpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template<class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
{
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for (std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
if(max_val < input(indices.begin(), indices.end()))
{
max_val = input(indices.begin(), indices.end());
max_index = i;
}
}
return max_index;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
......@@ -663,19 +681,7 @@ struct cpu_argmax
args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = batch_shape.multi(i);
auto max_val = input[i];
int64_t max_index = 0;
for(std::size_t j = 1; j < batch_item_num; ++j)
{
data_idx[op.axis] = j;
if(max_val < input(data_idx.begin(), data_idx.end()))
{
max_val = input(data_idx.begin(), data_idx.end());
max_index = j;
}
}
output[i] = max_index;
output[i] = this->calc_argmax(input, data_idx, batch_item_num, op.axis);
});
});
});
......@@ -697,6 +703,24 @@ struct cpu_argmin
std::string name() const { return "cpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template<class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
{
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for (std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
if(min_val > input(indices.begin(), indices.end()))
{
min_val = input(indices.begin(), indices.end());
min_index = i;
}
}
return min_index;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
......@@ -709,19 +733,7 @@ struct cpu_argmin
args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = batch_shape.multi(i);
auto min_val = input[i];
int64_t min_index = 0;
for(std::size_t j = 1; j < batch_item_num; ++j)
{
data_idx[op.axis] = j;
if(min_val > input(data_idx.begin(), data_idx.end()))
{
min_val = input(data_idx.begin(), data_idx.end());
min_index = j;
}
}
output[i] = min_index;
output[i] = this->calc_argmin(input, data_idx, batch_item_num, op.axis);
});
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment