Commit 38bde548 authored by Khalique's avatar Khalique
Browse files

formatting

parent 638db0ea
...@@ -291,7 +291,7 @@ struct tf_parser ...@@ -291,7 +291,7 @@ struct tf_parser
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0); l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
} }
auto l1 = args[1]; auto l1 = args[1];
if (is_nhwc) if(is_nhwc)
l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]); l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
return prog.add_instruction(op, {l0, l1}); return prog.add_instruction(op, {l0, l1});
} }
...@@ -412,9 +412,9 @@ struct tf_parser ...@@ -412,9 +412,9 @@ struct tf_parser
auto l0_dims = l0->get_shape().lens(); auto l0_dims = l0->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1 if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{ {
for (size_t i = 0; i < l0_dims.size(); i++) for(size_t i = 0; i < l0_dims.size(); i++)
{ {
if (l0_dims.at(i) == 1) if(l0_dims.at(i) == 1)
{ {
op.axes.push_back(i); op.axes.push_back(i);
} }
...@@ -680,8 +680,7 @@ struct tf_parser ...@@ -680,8 +680,7 @@ struct tf_parser
return literal{{shape::half_type, dims}, data_uint16}; return literal{{shape::half_type, dims}, data_uint16};
} }
case tensorflow::DataType::DT_DOUBLE: case tensorflow::DataType::DT_DOUBLE:
return literal{ return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
...@@ -726,7 +725,8 @@ struct tf_parser ...@@ -726,7 +725,8 @@ struct tf_parser
} }
template <class T> template <class T>
static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data, const std::size_t& shape_size) static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
const std::size_t& shape_size)
{ {
std::vector<T> data_vals(shape_size); std::vector<T> data_vals(shape_size);
// check if shape has enough data values given existing fields // check if shape has enough data values given existing fields
......
...@@ -39,7 +39,8 @@ TEST_CASE(batchnorm_test) ...@@ -39,7 +39,8 @@ TEST_CASE(batchnorm_test)
float momentum = 0.9f; float momentum = 0.9f;
migraphx::program p; migraphx::program p;
migraphx::op::batch_norm_inference op{epsilon, momentum, migraphx::op::batch_norm_inference::spatial}; migraphx::op::batch_norm_inference op{
epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
migraphx::shape s0{migraphx::shape::float_type, {32}}; migraphx::shape s0{migraphx::shape::float_type, {32}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 32}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 32}});
l0 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0); l0 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
...@@ -54,7 +55,6 @@ TEST_CASE(batchnorm_test) ...@@ -54,7 +55,6 @@ TEST_CASE(batchnorm_test)
auto prog = migraphx::parse_tf("batchnorm_test.pb", true); auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(biasadd_test) TEST_CASE(biasadd_test)
...@@ -107,10 +107,10 @@ TEST_CASE(conv_test) ...@@ -107,10 +107,10 @@ TEST_CASE(conv_test)
migraphx::op::convolution op; migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same; op.padding_mode = migraphx::op::padding_mode_t::same;
op.stride = {1,1}; op.stride = {1, 1};
op.dilation = {1,1}; op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0,3,1,2}}, l0); auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
auto l3 = p.add_instruction(migraphx::op::transpose{{3,2,0,1}}, l1); auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
p.add_instruction(op, l2, l3); p.add_instruction(op, l2, l3);
auto prog = migraphx::parse_tf("conv_test.pb", true); auto prog = migraphx::parse_tf("conv_test.pb", true);
...@@ -164,8 +164,8 @@ TEST_CASE(reshape_test) ...@@ -164,8 +164,8 @@ TEST_CASE(reshape_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}}; migraphx::shape s0{migraphx::shape::int32_type, {4}};
// in tf, the second arg is a literal that contains new dimensions // in tf, the second arg is a literal that contains new dimensions
p.add_literal(migraphx::literal{s0, {1,1,1,16}}); p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::reshape{{1,1,1,16}}, l0); p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
auto prog = migraphx::parse_tf("reshape_test.pb", false); auto prog = migraphx::parse_tf("reshape_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -187,7 +187,7 @@ TEST_CASE(softmax_test) ...@@ -187,7 +187,7 @@ TEST_CASE(softmax_test)
TEST_CASE(squeeze_test) TEST_CASE(squeeze_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1,2,3,1}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0); p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
auto prog = migraphx::parse_tf("squeeze_test.pb", false); auto prog = migraphx::parse_tf("squeeze_test.pb", false);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment