Commit b9a476e4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent beaccf94
...@@ -713,10 +713,10 @@ struct cpu_argmax ...@@ -713,10 +713,10 @@ struct cpu_argmax
auto data_idx = this->compute_batch_indices(i, batch_shape); auto data_idx = this->compute_batch_indices(i, batch_shape);
auto max_val = input[i]; auto max_val = input[i];
int64_t max_index = 0; int64_t max_index = 0;
for (size_t j = 1; j < batch_item_num; ++j) for(size_t j = 1; j < batch_item_num; ++j)
{ {
data_idx[op.axis] = j; data_idx[op.axis] = j;
if (max_val < input(data_idx.begin(), data_idx.end())) if(max_val < input(data_idx.begin(), data_idx.end()))
{ {
max_val = input(data_idx.begin(), data_idx.end()); max_val = input(data_idx.begin(), data_idx.end());
max_index = j; max_index = j;
...@@ -774,10 +774,10 @@ struct cpu_argmin ...@@ -774,10 +774,10 @@ struct cpu_argmin
auto data_idx = this->compute_batch_indices(i, batch_shape); auto data_idx = this->compute_batch_indices(i, batch_shape);
auto min_val = input[i]; auto min_val = input[i];
int64_t min_index = 0; int64_t min_index = 0;
for (size_t j = 1; j < batch_item_num; ++j) for(size_t j = 1; j < batch_item_num; ++j)
{ {
data_idx[op.axis] = j; data_idx[op.axis] = j;
if (min_val > input(data_idx.begin(), data_idx.end())) if(min_val > input(data_idx.begin(), data_idx.end()))
{ {
min_val = input(data_idx.begin(), data_idx.end()); min_val = input(data_idx.begin(), data_idx.end());
min_index = j; min_index = j;
......
...@@ -23,8 +23,7 @@ struct hip_argmax ...@@ -23,8 +23,7 @@ struct hip_argmax
std::string name() const { return "gpu::argmax"; } std::string name() const { return "gpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
...@@ -23,8 +23,7 @@ struct hip_argmin ...@@ -23,8 +23,7 @@ struct hip_argmin
std::string name() const { return "gpu::argmin"; } std::string name() const { return "gpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
...@@ -106,8 +106,12 @@ inline __device__ void reduce_argmax(T* data_ptr, ...@@ -106,8 +106,12 @@ inline __device__ void reduce_argmax(T* data_ptr,
} }
template <class T> template <class T>
inline __device__ void inline __device__ void reduce_argmin(T* data_ptr,
reduce_argmin(T* data_ptr, int64_t* index_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t min_index) int64_t* index_ptr,
size_t block_size,
size_t thr_idx,
size_t item_num,
size_t min_index)
{ {
while(true) while(true)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment