Commit b222af2f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 9bf18316
...@@ -12,10 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -12,10 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument argmax(hipStream_t stream, argument argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
const argument& result,
const argument& arg,
int axis)
{ {
auto lens = arg.get_shape().lens(); auto lens = arg.get_shape().lens();
auto batch_lens = lens; auto batch_lens = lens;
...@@ -30,44 +27,47 @@ argument argmax(hipStream_t stream, ...@@ -30,44 +27,47 @@ argument argmax(hipStream_t stream,
hip_tensor_descriptor<n_dim> desc_batch(batch_shape); hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
hip_tensor_descriptor<n_dim> desc_data(arg.get_shape()); hip_tensor_descriptor<n_dim> desc_data(arg.get_shape());
// each block is for one batch // each block is for one batch
const size_t block_size = 1024; const size_t block_size = 1024;
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ { launch(
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; size_t thr_idx = idx.local;
size_t blk_idx = idx.group; size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto batch_idx = desc_batch.multi(blk_idx); auto batch_idx = desc_batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
MIGRAPHX_DEVICE_SHARED type lds_data[block_size]; MIGRAPHX_DEVICE_SHARED type lds_data[block_size];
MIGRAPHX_DEVICE_SHARED int64_t lds_index[block_size]; MIGRAPHX_DEVICE_SHARED int64_t lds_index[block_size];
// load data to lds_data // load data to lds_data
size_t item_num = n_dims; size_t item_num = n_dims;
for (size_t i = thr_idx; i < n_dims; i += block_size) for(size_t i = thr_idx; i < n_dims; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_index[thr_idx] = i; lds_index[thr_idx] = i;
lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)]; lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
auto stride = (size + 1) / 2; auto stride = (size + 1) / 2;
while (true) while(true)
{ {
if (thr_idx + stride < size and lds_data[thr_idx] < lds_data[thr_idx + stride]) if(thr_idx + stride < size and
lds_data[thr_idx] < lds_data[thr_idx + stride])
{ {
lds_data[thr_idx] = lds_data[thr_idx + stride]; lds_data[thr_idx] = lds_data[thr_idx + stride];
lds_index[thr_idx] = lds_index[thr_idx + stride]; lds_index[thr_idx] = lds_index[thr_idx + stride];
} }
__syncthreads(); __syncthreads();
size = stride; size = stride;
stride = (stride + 1) / 2; stride = (stride + 1) / 2;
if (size == 1) break; if(size == 1)
break;
} }
if (thr_idx == 0) if(thr_idx == 0)
{ {
output_ptr[blk_idx] = lds_index[0]; output_ptr[blk_idx] = lds_index[0];
} }
......
...@@ -12,10 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -12,10 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument argmax(hipStream_t stream, argument argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
const argument& result,
const argument& arg,
int axis)
{ {
auto lens = arg.get_shape().lens(); auto lens = arg.get_shape().lens();
auto batch_lens = lens; auto batch_lens = lens;
...@@ -30,44 +27,47 @@ argument argmax(hipStream_t stream, ...@@ -30,44 +27,47 @@ argument argmax(hipStream_t stream,
hip_tensor_descriptor<n_dim> desc_batch(batch_shape); hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
hip_tensor_descriptor<n_dim> desc_data(arg.get_shape()); hip_tensor_descriptor<n_dim> desc_data(arg.get_shape());
// each block is for one batch // each block is for one batch
const size_t block_size = 1024; const size_t block_size = 1024;
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ { launch(
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; size_t thr_idx = idx.local;
size_t blk_idx = idx.group; size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto batch_idx = desc_batch.multi(blk_idx); auto batch_idx = desc_batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
MIGRAPHX_DEVICE_SHARED type lds_data[block_size]; MIGRAPHX_DEVICE_SHARED type lds_data[block_size];
MIGRAPHX_DEVICE_SHARED int64_t lds_index[block_size]; MIGRAPHX_DEVICE_SHARED int64_t lds_index[block_size];
// load data to lds_data // load data to lds_data
size_t item_num = n_dims; size_t item_num = n_dims;
for (size_t i = thr_idx; i < n_dims; i += block_size) for(size_t i = thr_idx; i < n_dims; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_index[thr_idx] = i; lds_index[thr_idx] = i;
lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)]; lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
auto stride = (size + 1) / 2; auto stride = (size + 1) / 2;
while (true) while(true)
{ {
if (thr_idx + stride < size and lds_data[thr_idx] > lds_data[thr_idx + stride]) if(thr_idx + stride < size and
lds_data[thr_idx] > lds_data[thr_idx + stride])
{ {
lds_data[thr_idx] = lds_data[thr_idx + stride]; lds_data[thr_idx] = lds_data[thr_idx + stride];
lds_index[thr_idx] = lds_index[thr_idx + stride]; lds_index[thr_idx] = lds_index[thr_idx + stride];
} }
__syncthreads(); __syncthreads();
size = stride; size = stride;
stride = (stride + 1) / 2; stride = (stride + 1) / 2;
if (size == 1) break; if(size == 1)
break;
} }
if (thr_idx == 0) if(thr_idx == 0)
{ {
output_ptr[blk_idx] = lds_index[0]; output_ptr[blk_idx] = lds_index[0];
} }
......
...@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument argmax(hipStream_t stream, argument argmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
const argument& result,
const argument& arg,
int axis);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument argmin(hipStream_t stream, argument argmin(hipStream_t stream, const argument& result, const argument& arg, int axis);
const argument& result,
const argument& arg,
int axis);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment