Commit b16ab01d authored by Shucai Xiao

code backup

parent 31b53e5a
@@ -53,14 +53,15 @@ inline __device__ void reduce_argmax(T* data_ptr,
 void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
-    auto lens = arg.get_shape().lens();
+    auto arg_shape = arg.get_shape();
+    auto lens = arg_shape.lens();
     auto batch_lens = lens;
     size_t batch_item_num = lens[axis];
     batch_lens[axis] = 1;
-    migraphx::shape batch_shape{shape::int64_type, batch_lens};
-    auto arg_shape = arg.get_shape();
+    migraphx::shape batch_shape{arg_shape.type(), batch_lens};
-    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
+    hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
+        auto output = device_cast(result.get<int64_t>().data());
         // use one block for items in one batch.
         const size_t max_block_size = 1024;
         size_t block_size = 1;
@@ -74,14 +75,15 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
             size_t blk_idx = idx.group;
             using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-            auto batch_idx = batch.multi(blk_idx);
+            auto batch_idx = batch_s.multi(blk_idx);
             auto data_idx = batch_idx;
             MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
             MIGRAPHX_DEVICE_SHARED int64_t lds_index[max_block_size + 1];
             // load data to lds_data
             size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
             size_t remaining_item_num = batch_item_num;
-            lds_data[max_block_size] = input[0];
+            data_idx[axis] = 0;
+            lds_data[max_block_size] = input[arg_s.index(data_idx)];
             lds_index[max_block_size] = 0;
             for(size_t i = thr_idx; i < round_item_num; i += block_size)
             {
@@ -89,7 +91,7 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
                 {
                     data_idx[axis] = i;
                     lds_index[thr_idx] = i;
-                    lds_data[thr_idx] = input[data_idx];
+                    lds_data[thr_idx] = input[arg_s.index(data_idx)];
                 }
                 __syncthreads();
@@ -101,7 +103,7 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
             if(thr_idx == 0)
             {
-                output[batch_idx] = lds_index[max_block_size];
+                output[batch_s.index(batch_idx)] = lds_index[max_block_size];
             }
         });
     });
......
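The recurring change in this file swaps direct multi-index subscripting (`input[data_idx]`) for `input[arg_s.index(data_idx)]`: the multi-dimensional index is converted to a linear element offset through the shape's strides before the raw data pointer is dereferenced, which is presumably the point of threading `arg_shape` through `hip_visit_all`. A minimal host-side sketch of that stride-based index computation (plain C++; `linear_index` is an illustrative stand-in, not MIGraphX's actual `shape::index`):

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Illustrative stand-in for a shape's index(): map a multi-dimensional
// index to a linear element offset using per-dimension strides.
std::size_t linear_index(const std::vector<std::size_t>& strides,
                         const std::vector<std::size_t>& multi_idx)
{
    // Dot product of the multi-index with the strides.
    return std::inner_product(
        multi_idx.begin(), multi_idx.end(), strides.begin(), std::size_t{0});
}

int main()
{
    // Row-major (packed) strides for a {2, 3, 4} tensor are {12, 4, 1},
    // so element {1, 2, 3} lives at offset 1*12 + 2*4 + 3*1 = 23.
    assert(linear_index({12, 4, 1}, {1, 2, 3}) == 23);
    // With a non-packed layout only the strides change: padded rows of
    // pitch 16 put the same element at 1*16 + 2*4 + 3*1 = 27.
    assert(linear_index({16, 4, 1}, {1, 2, 3}) == 27);
    return 0;
}

Going through the strides this way keeps the load correct for non-packed or transposed layouts, where element {i, j, k} is not at the packed row-major offset.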
@@ -53,13 +53,15 @@ inline __device__ void reduce_argmin(T* data_ptr,
 void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
-    auto lens = arg.get_shape().lens();
+    auto arg_shape = arg.get_shape();
+    auto lens = arg_shape.lens();
     auto batch_lens = lens;
     size_t batch_item_num = lens[axis];
     batch_lens[axis] = 1;
-    migraphx::shape batch_shape{shape::float_type, batch_lens};
+    migraphx::shape batch_shape{arg_shape.type(), batch_lens};
-    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
+    hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
+        auto output = device_cast(result.get<int64_t>().data());
         // use one block for items in one batch.
         const size_t max_block_size = 1024;
         size_t block_size = 1;
@@ -71,16 +73,17 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
         launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
             size_t thr_idx = idx.local;
             size_t blk_idx = idx.group;
-            using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
+            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-            auto batch_idx = batch.multi(blk_idx);
+            auto batch_idx = batch_s.multi(blk_idx);
             auto data_idx = batch_idx;
             MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
             MIGRAPHX_DEVICE_SHARED int64_t lds_index[max_block_size + 1];
             // load data to lds_data
             size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
             size_t remaining_item_num = batch_item_num;
-            lds_data[max_block_size] = input[0];
+            data_idx[axis] = 0;
+            lds_data[max_block_size] = input[arg_s.index(data_idx)];
             lds_index[max_block_size] = 0;
             for(size_t i = thr_idx; i < round_item_num; i += block_size)
             {
@@ -88,7 +91,7 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
                 {
                     data_idx[axis] = i;
                     lds_index[thr_idx] = i;
-                    lds_data[thr_idx] = input[data_idx];
+                    lds_data[thr_idx] = input[arg_s.index(data_idx)];
                 }
                 __syncthreads();
@@ -100,7 +103,7 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
             if(thr_idx == 0)
             {
-                output[batch_idx] = lds_index[max_block_size];
+                output[batch_s.index(batch_idx)] = lds_index[max_block_size];
             }
         });
     });
......
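Both kernels rely on the same barrier-safety idiom: `round_item_num` rounds the per-batch item count up to a multiple of `block_size`, so every thread in the block executes the same number of loop iterations and reaches each `__syncthreads()`, while threads past `remaining_item_num` simply skip the load. The round-up arithmetic, checked on the host (`round_up` is an illustrative helper name, not part of the patch):

#include <cassert>
#include <cstddef>

// Round n up to the next multiple of block_size (the round_item_num idiom).
std::size_t round_up(std::size_t n, std::size_t block_size)
{
    return (n + block_size - 1) / block_size * block_size;
}

int main()
{
    assert(round_up(1000, 256) == 1024); // 4 full strides of 256 threads
    assert(round_up(1024, 256) == 1024); // already a multiple
    assert(round_up(1, 256) == 256);     // a single item still takes one stride
    return 0;
}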
@@ -3,6 +3,7 @@
 #include <migraphx/gpu/hip.hpp>
 #include <migraphx/gpu/device/types.hpp>
 #include <migraphx/gpu/device/reduce_opers.hpp>
+#include <migraphx/gpu/device/nary.hpp>
 
 namespace migraphx {
@@ -59,6 +60,56 @@ inline __device__ void block_reduce(T* data_ptr,
     __syncthreads();
 }
 
+template <class T, class F>
+struct pair_max_op
+{
+    using type = std::pair<T, F>;
+    type operator()(type x, type y) const { return (x.first > y.first) ? x : y; }
+};
+
+template <class T, class F>
+struct pair_min_op
+{
+    using type = std::pair<T, F>;
+    type operator()(type x, type y) const { return (x.first < y.first) ? x : y; }
+};
+
+template <class T, class Op>
+inline __device__ void block_reduce_pair(T* data_ptr,
+                                         int64_t* index_ptr,
+                                         Op op,
+                                         std::size_t block_size,
+                                         std::size_t thr_idx,
+                                         std::size_t item_num,
+                                         std::size_t output_index)
+{
+    while(true)
+    {
+        auto stride = (item_num + 1) / 2;
+        auto size = item_num / 2;
+        for(std::size_t i = thr_idx; i < size; i += block_size)
+        {
+            auto output = op({data_ptr[i], index_ptr[i]},
+                             {data_ptr[i + stride], index_ptr[i + stride]});
+            data_ptr[i] = output.first;
+            index_ptr[i] = output.second;
+        }
+        __syncthreads();
+        item_num = stride;
+        if(item_num == 1)
+            break;
+    }
+
+    if(thr_idx == 0)
+    {
+        auto output = op({data_ptr[output_index], index_ptr[output_index]},
+                         {data_ptr[0], index_ptr[0]});
+        data_ptr[output_index] = output.first;
+        index_ptr[output_index] = output.second;
+    }
+    __syncthreads();
+}
+
 } // namespace device
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
......
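`block_reduce_pair` performs a tree reduction over (value, index) pairs held in LDS: each pass combines element `i` with element `i + stride`, where `stride` is half the remaining item count rounded up, and the comparator (`pair_max_op` or `pair_min_op`) decides which pair survives; thread 0 finally folds the winner into the slot at `output_index`. A sequential host emulation of the same halving scheme (plain C++, illustrative only; it returns the winning pair instead of writing it to `output_index`):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Host emulation of the halving loop in block_reduce_pair: each pass
// combines element i with element i + stride until one pair remains.
// Assumes data and index are non-empty and the same length.
template <class T, class Op>
std::pair<T, std::int64_t>
reduce_pairs(std::vector<T> data, std::vector<std::int64_t> index, Op op)
{
    std::size_t item_num = data.size();
    while(item_num > 1)
    {
        std::size_t stride = (item_num + 1) / 2; // round up, odd element carries over
        for(std::size_t i = 0; i < item_num / 2; ++i)
        {
            auto out = op(std::make_pair(data[i], index[i]),
                          std::make_pair(data[i + stride], index[i + stride]));
            data[i]  = out.first;
            index[i] = out.second;
        }
        item_num = stride;
    }
    return {data[0], index[0]};
}

int main()
{
    std::vector<float> data{3.0f, 9.0f, 1.0f, 7.0f, 4.0f};
    std::vector<std::int64_t> index{0, 1, 2, 3, 4};
    auto argmax = reduce_pairs(
        data, index, [](auto x, auto y) { return x.first > y.first ? x : y; });
    std::cout << argmax.second << "\n"; // prints 1: the index of the maximum 9.0
    return 0;
}

In the device version the same passes run in parallel, with thread `i` handling every `block_size`-th pair and `__syncthreads()` separating the passes.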
@@ -617,7 +617,7 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, KeepDims>>
     migraphx::program create_program() const
     {
         migraphx::program p;
-        migraphx::shape s{migraphx::shape::float_type, {2047, 2, 1025, 4}};
+        migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
         auto param = p.add_parameter("data", s);
         p.add_instruction(T{Axis, KeepDims}, param);
......