Commit a7349547 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add test example to argmax and argmin operators.

parent 7d74dd8c
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include <migraphx/gpu/device/tensor.hpp> #include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
namespace migraphx { namespace migraphx {
...@@ -13,13 +12,53 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -13,13 +12,53 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
template <class T>
inline __device__ void reduce_argmax(T* data_ptr,
int64_t* index_ptr,
std::size_t block_size,
std::size_t thr_idx,
std::size_t item_num,
std::size_t max_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(std::size_t i = thr_idx; i < size; i += block_size)
{
if(data_ptr[i] < data_ptr[i + stride])
{
data_ptr[i] = data_ptr[i + stride];
index_ptr[i] = index_ptr[i + stride];
}
}
__syncthreads();
item_num = stride;
if(item_num == 1)
break;
}
if(thr_idx == 0)
{
if(data_ptr[max_index] < data_ptr[0])
{
data_ptr[max_index] = data_ptr[0];
index_ptr[max_index] = index_ptr[0];
}
}
__syncthreads();
}
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis) void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{ {
auto lens = arg.get_shape().lens(); auto lens = arg.get_shape().lens();
auto batch_lens = lens; auto batch_lens = lens;
size_t batch_item_num = lens[axis]; size_t batch_item_num = lens[axis];
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{shape::float_type, batch_lens}; migraphx::shape batch_shape{shape::int64_type, batch_lens};
auto arg_shape = arg.get_shape();
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) { hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
// use one block for items in one batch. // use one block for items in one batch.
...@@ -33,7 +72,7 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int ...@@ -33,7 +72,7 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ { launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; size_t thr_idx = idx.local;
size_t blk_idx = idx.group; size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
auto batch_idx = batch.multi(blk_idx); auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include <migraphx/gpu/device/tensor.hpp> #include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
namespace migraphx { namespace migraphx {
...@@ -13,6 +12,45 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -13,6 +12,45 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
template <class T>
inline __device__ void reduce_argmin(T* data_ptr,
int64_t* index_ptr,
std::size_t block_size,
std::size_t thr_idx,
std::size_t item_num,
std::size_t min_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(std::size_t i = thr_idx; i < size; i += block_size)
{
if(data_ptr[i] > data_ptr[i + stride])
{
data_ptr[i] = data_ptr[i + stride];
index_ptr[i] = index_ptr[i + stride];
}
}
__syncthreads();
item_num = stride;
if(item_num == 1)
break;
}
if(thr_idx == 0)
{
if(data_ptr[min_index] > data_ptr[0])
{
data_ptr[min_index] = data_ptr[0];
index_ptr[min_index] = index_ptr[0];
}
}
__syncthreads();
}
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis) void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis)
{ {
auto lens = arg.get_shape().lens(); auto lens = arg.get_shape().lens();
......
...@@ -941,9 +941,6 @@ TEST_CASE(softmax_simple_test) ...@@ -941,9 +941,6 @@ TEST_CASE(softmax_simple_test)
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(2); std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for(auto v : results_vector)
std::cout << v << "\t";
std::cout << std::endl;
EXPECT(migraphx::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
...@@ -1138,6 +1135,129 @@ TEST_CASE(logsoftmax_test_axis_3) ...@@ -1138,6 +1135,129 @@ TEST_CASE(logsoftmax_test_axis_3)
EXPECT(migraphx::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
template<int KeepDims>
void argmax_test_0()
{
migraphx::program p;
std::vector<float> data = {
1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{0, KeepDims}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_00) { argmax_test_0<0>(); }
TEST_CASE(argmax_test_01) { argmax_test_0<1>(); }
TEST_CASE(argmax_test_1)
{
migraphx::program p;
std::vector<float> data = {
1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{1, 0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_2)
{
migraphx::program p;
std::vector<float> data = {
1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {1, 3, 2, 2, 2, 3};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{2, 0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
template<int KeepDims>
void argmin_test_0()
{
migraphx::program p;
std::vector<float> data = {
1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{0, KeepDims}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_00) { argmin_test_0<0>(); }
TEST_CASE(argmin_test_01) { argmin_test_0<1>(); }
TEST_CASE(argmin_test_1)
{
migraphx::program p;
std::vector<float> data = {
1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 2, 0, 2, 0, 1, 2, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{1, 0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_2)
{
migraphx::program p;
std::vector<float> data = {
1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{2, 0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(conv2d_test) TEST_CASE(conv2d_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -611,6 +611,32 @@ template struct test_softmax<1>; ...@@ -611,6 +611,32 @@ template struct test_softmax<1>;
template struct test_softmax<2>; template struct test_softmax<2>;
template struct test_softmax<3>; template struct test_softmax<3>;
template <class T, int Axis, int KeepDims>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, KeepDims>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2047, 2, 1025, 4}};
auto param = p.add_parameter("data", s);
p.add_instruction(T{Axis, KeepDims}, param);
return p;
}
};
template struct test_arg_ops<migraphx::op::argmax, 0, 0>;
template struct test_arg_ops<migraphx::op::argmax, 0, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 0>;
template struct test_arg_ops<migraphx::op::argmin, 0, 0>;
template struct test_arg_ops<migraphx::op::argmin, 0, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>;
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment