"docs/vscode:/vscode.git/clone" did not exist on "09f1a247ce48811de3ea9c73f71398322813c87d"
Commit 511c8d8f authored by Paul's avatar Paul
Browse files

Merge from develop

parents 9b7c44ab 2a2c146c
...@@ -346,60 +346,39 @@ TEST_CASE(gather) ...@@ -346,60 +346,39 @@ TEST_CASE(gather)
} }
} }
TEST_CASE(logsoftmax) template <class T>
void test_softmax_variations()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 0; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 1; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 2; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 3; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4; int axis = 4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, throws_shape(T{axis}, input);
migraphx::op::logsoftmax{axis},
input);
} }
}
{ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 5;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
{ TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = -1;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
}
// 2 inputs arguments // 2 inputs arguments
TEST_CASE(matmul) TEST_CASE(matmul)
......
...@@ -124,6 +124,7 @@ TEST_CASE(conv_test) ...@@ -124,6 +124,7 @@ TEST_CASE(conv_test)
migraphx::op::convolution op; migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same; op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1};
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1); auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
...@@ -145,6 +146,7 @@ TEST_CASE(depthwiseconv_test) ...@@ -145,6 +146,7 @@ TEST_CASE(depthwiseconv_test)
migraphx::op::convolution op; migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same; op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1};
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
op.group = 3; op.group = 3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment