Commit 511c8d8f authored by Paul's avatar Paul
Browse files

Merge from develop

parents 9b7c44ab 2a2c146c
......@@ -346,60 +346,39 @@ TEST_CASE(gather)
}
}
TEST_CASE(logsoftmax)
template <class T>
void test_softmax_variations()
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 2;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
throws_shape(T{axis}, input);
}
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 5;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = -1;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
}
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
// 2 inputs arguments
TEST_CASE(matmul)
......
......@@ -124,6 +124,7 @@ TEST_CASE(conv_test)
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
......@@ -145,6 +146,7 @@ TEST_CASE(depthwiseconv_test)
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
op.group = 3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment