Commit d7d8b1c2 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add shape and onnx test for argmax and argmin

parent 6596ee39
...@@ -34,7 +34,7 @@ inline __device__ void block_reduce(T* data_ptr, ...@@ -34,7 +34,7 @@ inline __device__ void block_reduce(T* data_ptr,
std::size_t block_size, std::size_t block_size,
std::size_t thr_idx, std::size_t thr_idx,
std::size_t item_num, std::size_t item_num,
std::size_t max_index) std::size_t output_index)
{ {
while(true) while(true)
{ {
...@@ -54,9 +54,7 @@ inline __device__ void block_reduce(T* data_ptr, ...@@ -54,9 +54,7 @@ inline __device__ void block_reduce(T* data_ptr,
if(thr_idx == 0) if(thr_idx == 0)
{ {
// data_ptr[max_index] = data_ptr[output_index] = op(data_ptr[output_index], data_ptr[0]);
// (data_ptr[0] < data_ptr[max_index]) ? data_ptr[max_index] : data_ptr[0];
data_ptr[max_index] = op(data_ptr[max_index], data_ptr[0]);
} }
__syncthreads(); __syncthreads();
......
...@@ -784,6 +784,26 @@ TEST_CASE(logsoftmax) ...@@ -784,6 +784,26 @@ TEST_CASE(logsoftmax)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(argmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::argmax{2, 0}, l0);
auto prog = migraphx::parse_onnx("argmax_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(argmin)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::argmin{3, 0}, l0);
auto prog = migraphx::parse_onnx("argmin_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(no_pad_test) TEST_CASE(no_pad_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -380,6 +380,58 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); } ...@@ -380,6 +380,58 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); } TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
template <class T>
void test_argop_var()
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0, 1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1, 1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2, 1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3, 1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {3, 4, 5}}, T{0, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 4, 5}}, T{1, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 5}}, T{2, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4}}, T{3, 0}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{4, 1}, input);
}
}
TEST_CASE(argmax) { test_argop_var<migraphx::op::argmax>(); }
TEST_CASE(argmin) { test_argop_var<migraphx::op::argmin>(); }
// 2 inputs arguments // 2 inputs arguments
TEST_CASE(matmul) TEST_CASE(matmul)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment