Commit 195753c4 authored by Khalique's avatar Khalique
Browse files

unified softmax and logsoftmax tests

parent 7df6328c
...@@ -346,61 +346,68 @@ TEST_CASE(gather) ...@@ -346,61 +346,68 @@ TEST_CASE(gather)
} }
} }
TEST_CASE(logsoftmax) template<class T>
void test_softmax_variations(T, bool is_logsoftmax)
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis}, T{0},
input); input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis}, T{1},
input); input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 2;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis}, T{2},
input); input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis}, T{3},
input); input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4; throws_shape(T{5}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 5; throws_shape(T{-1}, input);
throws_shape(migraphx::op::logsoftmax{axis}, input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = -1; if(is_logsoftmax)
throws_shape(migraphx::op::logsoftmax{axis}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
T{4},
input);
else
throws_shape(T{4}, input);
} }
} }
TEST_CASE(softmax)
{
test_softmax_variations(migraphx::op::softmax{}, false);
}
TEST_CASE(logsoftmax)
{
test_softmax_variations(migraphx::op::logsoftmax{}, true);
}
// 2 inputs arguments // 2 inputs arguments
TEST_CASE(matmul) TEST_CASE(matmul)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment