Commit b9a476e4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent beaccf94
...@@ -702,23 +702,23 @@ struct cpu_argmax ...@@ -702,23 +702,23 @@ struct cpu_argmax
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = args.front().get_shape().lens(); auto batch_lens = args.front().get_shape().lens();
size_t batch_item_num = batch_lens[op.axis]; size_t batch_item_num = batch_lens[op.axis];
batch_lens[op.axis] = 1; batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) { par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = this->compute_batch_indices(i, batch_shape); auto data_idx = this->compute_batch_indices(i, batch_shape);
auto max_val = input[i]; auto max_val = input[i];
int64_t max_index = 0; int64_t max_index = 0;
for (size_t j = 1; j < batch_item_num; ++j) for(size_t j = 1; j < batch_item_num; ++j)
{ {
data_idx[op.axis] = j; data_idx[op.axis] = j;
if (max_val < input(data_idx.begin(), data_idx.end())) if(max_val < input(data_idx.begin(), data_idx.end()))
{ {
max_val = input(data_idx.begin(), data_idx.end()); max_val = input(data_idx.begin(), data_idx.end());
max_index = j; max_index = j;
} }
} }
...@@ -763,23 +763,23 @@ struct cpu_argmin ...@@ -763,23 +763,23 @@ struct cpu_argmin
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = args.front().get_shape().lens(); auto batch_lens = args.front().get_shape().lens();
size_t batch_item_num = batch_lens[op.axis]; size_t batch_item_num = batch_lens[op.axis];
batch_lens[op.axis] = 1; batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) { par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = this->compute_batch_indices(i, batch_shape); auto data_idx = this->compute_batch_indices(i, batch_shape);
auto min_val = input[i]; auto min_val = input[i];
int64_t min_index = 0; int64_t min_index = 0;
for (size_t j = 1; j < batch_item_num; ++j) for(size_t j = 1; j < batch_item_num; ++j)
{ {
data_idx[op.axis] = j; data_idx[op.axis] = j;
if (min_val > input(data_idx.begin(), data_idx.end())) if(min_val > input(data_idx.begin(), data_idx.end()))
{ {
min_val = input(data_idx.begin(), data_idx.end()); min_val = input(data_idx.begin(), data_idx.end());
min_index = j; min_index = j;
} }
} }
......
...@@ -23,8 +23,7 @@ struct hip_argmax ...@@ -23,8 +23,7 @@ struct hip_argmax
std::string name() const { return "gpu::argmax"; } std::string name() const { return "gpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
...@@ -23,8 +23,7 @@ struct hip_argmin ...@@ -23,8 +23,7 @@ struct hip_argmin
std::string name() const { return "gpu::argmin"; } std::string name() const { return "gpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
...@@ -106,8 +106,12 @@ inline __device__ void reduce_argmax(T* data_ptr, ...@@ -106,8 +106,12 @@ inline __device__ void reduce_argmax(T* data_ptr,
} }
template <class T> template <class T>
inline __device__ void inline __device__ void reduce_argmin(T* data_ptr,
reduce_argmin(T* data_ptr, int64_t* index_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t min_index) int64_t* index_ptr,
size_t block_size,
size_t thr_idx,
size_t item_num,
size_t min_index)
{ {
while(true) while(true)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment