Commit 4012e94a authored by Khalique's avatar Khalique
Browse files

formatting

parent 82201cc0
...@@ -575,7 +575,7 @@ struct tf_parser ...@@ -575,7 +575,7 @@ struct tf_parser
instruction_ref instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
auto lens = args[0]->get_shape().lens(); auto lens = args[0]->get_shape().lens();
auto axes = args[1]->eval().get<int32_t>().to_vector(); auto axes = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end()); std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end());
...@@ -591,7 +591,7 @@ struct tf_parser ...@@ -591,7 +591,7 @@ struct tf_parser
{ {
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>()); size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
int64_t axis = -1; int64_t axis = -1;
float on_value = args[2]->eval().at<float>(); float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>(); float off_value = args[3]->eval().at<float>();
......
...@@ -270,8 +270,8 @@ TEST_CASE(mean_test_nhwc) ...@@ -270,8 +270,8 @@ TEST_CASE(mean_test_nhwc)
migraphx::program p; migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = p.add_instruction(migraphx::op::transpose{{0,2,3,1}}, l0); auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l2 = p.add_instruction(migraphx::op::reduce_mean{{1,2}}, l1); auto l2 = p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l1);
p.add_instruction(migraphx::op::squeeze{{1, 2}}, l2); p.add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
auto prog = optimize_tf("mean_test_nhwc.pb", true); auto prog = optimize_tf("mean_test_nhwc.pb", true);
...@@ -293,11 +293,13 @@ TEST_CASE(mul_test) ...@@ -293,11 +293,13 @@ TEST_CASE(mul_test)
TEST_CASE(onehot_test) TEST_CASE(onehot_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}}); auto l0 = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
p.add_literal(2); p.add_literal(2);
p.add_literal(1.0f); p.add_literal(1.0f);
p.add_literal(0.0f); p.add_literal(0.0f);
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2,2}}, {1, 0, 0, 1}}); auto l1 = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
int axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l0); p.add_instruction(migraphx::op::gather{axis}, l1, l0);
auto prog = optimize_tf("onehot_test.pb", false); auto prog = optimize_tf("onehot_test.pb", false);
...@@ -489,7 +491,7 @@ TEST_CASE(stridedslice_test) ...@@ -489,7 +491,7 @@ TEST_CASE(stridedslice_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = p.add_instruction(migraphx::op::transpose{{0,2,3,1}}, l0); auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
std::size_t num_axes = 4; std::size_t num_axes = 4;
migraphx::op::slice op; migraphx::op::slice op;
op.starts = {0, 0, 0, 0}; op.starts = {0, 0, 0, 0};
...@@ -519,9 +521,9 @@ TEST_CASE(stridedslice_masks_test) ...@@ -519,9 +521,9 @@ TEST_CASE(stridedslice_masks_test)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0}); p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1}); p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(migraphx::op::transpose{{0,2,3,1}}, l0); auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l2 = p.add_instruction(op, l1); auto l2 = p.add_instruction(op, l1);
p.add_instruction(migraphx::op::transpose{{0,3,1,2}}, l2); p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l2);
auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true); auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment