Commit b9a476e4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent beaccf94
......@@ -713,10 +713,10 @@ struct cpu_argmax
auto data_idx = this->compute_batch_indices(i, batch_shape);
auto max_val = input[i];
int64_t max_index = 0;
for (size_t j = 1; j < batch_item_num; ++j)
for(size_t j = 1; j < batch_item_num; ++j)
{
data_idx[op.axis] = j;
if (max_val < input(data_idx.begin(), data_idx.end()))
if(max_val < input(data_idx.begin(), data_idx.end()))
{
max_val = input(data_idx.begin(), data_idx.end());
max_index = j;
......@@ -774,10 +774,10 @@ struct cpu_argmin
auto data_idx = this->compute_batch_indices(i, batch_shape);
auto min_val = input[i];
int64_t min_index = 0;
for (size_t j = 1; j < batch_item_num; ++j)
for(size_t j = 1; j < batch_item_num; ++j)
{
data_idx[op.axis] = j;
if (min_val > input(data_idx.begin(), data_idx.end()))
if(min_val > input(data_idx.begin(), data_idx.end()))
{
min_val = input(data_idx.begin(), data_idx.end());
min_index = j;
......
......@@ -23,8 +23,7 @@ struct hip_argmax
std::string name() const { return "gpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
......
......@@ -23,8 +23,7 @@ struct hip_argmin
std::string name() const { return "gpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
......
......@@ -106,8 +106,12 @@ inline __device__ void reduce_argmax(T* data_ptr,
}
template <class T>
inline __device__ void
reduce_argmin(T* data_ptr, int64_t* index_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t min_index)
inline __device__ void reduce_argmin(T* data_ptr,
int64_t* index_ptr,
size_t block_size,
size_t thr_idx,
size_t item_num,
size_t min_index)
{
while(true)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment