Commit 38bde548 authored by Khalique

formatting

parent 638db0ea
@@ -291,7 +291,7 @@ struct tf_parser
             l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
         }
         auto l1 = args[1];
-        if (is_nhwc)
+        if(is_nhwc)
             l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
         return prog.add_instruction(op, {l0, l1});
     }
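
Note on the transpose permutations in this hunk: TensorFlow defaults to NHWC layout while MIGraphX convolutions work in NCHW, so the parser permutes activations with {0, 3, 1, 2} and weights (HWIO in TF) with {3, 2, 0, 1}. A minimal standalone C++ sketch of how such a permutation reindexes dimensions; permute_dims is a hypothetical helper written for illustration, not part of MIGraphX:

#include <cstddef>
#include <iostream>
#include <vector>

// Reorder a dimension vector: output axis i takes input axis perm[i].
std::vector<std::size_t> permute_dims(const std::vector<std::size_t>& dims,
                                      const std::vector<std::size_t>& perm)
{
    std::vector<std::size_t> out(perm.size());
    for(std::size_t i = 0; i < perm.size(); i++)
        out[i] = dims[perm[i]];
    return out;
}

int main()
{
    // NHWC activations {1, 16, 16, 32} -> NCHW {1, 32, 16, 16}
    for(auto d : permute_dims({1, 16, 16, 32}, {0, 3, 1, 2}))
        std::cout << d << ' ';
    std::cout << '\n';
    // HWIO weights {3, 3, 32, 64} -> OIHW {64, 32, 3, 3} (hypothetical sizes)
    for(auto d : permute_dims({3, 3, 32, 64}, {3, 2, 0, 1}))
        std::cout << d << ' ';
    std::cout << '\n';
}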
@@ -412,9 +412,9 @@ struct tf_parser
         auto l0_dims = l0->get_shape().lens();
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
         {
-            for (size_t i = 0; i < l0_dims.size(); i++)
+            for(size_t i = 0; i < l0_dims.size(); i++)
             {
-                if (l0_dims.at(i) == 1)
+                if(l0_dims.at(i) == 1)
                 {
                     op.axes.push_back(i);
                 }
@@ -680,8 +680,7 @@ struct tf_parser
             return literal{{shape::half_type, dims}, data_uint16};
         }
         case tensorflow::DataType::DT_DOUBLE:
-            return literal{
-                {shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
+            return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
         case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
         case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
         case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
@@ -726,7 +725,8 @@ struct tf_parser
     }
     template <class T>
-    static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data, const std::size_t& shape_size)
+    static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
+                                        const std::size_t& shape_size)
     {
         std::vector<T> data_vals(shape_size);
         // check if shape has enough data values given existing fields
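
The rewrapped signature above sizes its result from the shape rather than from the protobuf field, per the comment that follows it ("check if shape has enough data values given existing fields"). A hedged sketch of the likely fill behavior, assuming TensorFlow's TensorProto convention that a value list shorter than the shape repeats its last element; get_data_vals_sketch is illustrative only, and a std::vector stands in for google::protobuf::RepeatedField<T> to stay self-contained:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Fill a tensor-sized buffer from a possibly shorter value list; the last
// stored value is repeated (assumed TensorProto convention, not confirmed
// to be what MIGraphX does here).
template <class T>
std::vector<T> get_data_vals_sketch(const std::vector<T>& data, std::size_t shape_size)
{
    std::vector<T> data_vals(shape_size);
    for(std::size_t i = 0; i < shape_size; i++)
        data_vals[i] = data[std::min(i, data.size() - 1)]; // assumes data is non-empty
    return data_vals;
}

int main()
{
    for(auto x : get_data_vals_sketch<float>({1.0f, 2.0f}, 4))
        std::cout << x << ' '; // prints: 1 2 2 2
    std::cout << '\n';
}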
@@ -39,7 +39,8 @@ TEST_CASE(batchnorm_test)
     float momentum = 0.9f;
     migraphx::program p;
-    migraphx::op::batch_norm_inference op{epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
+    migraphx::op::batch_norm_inference op{
+        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
     migraphx::shape s0{migraphx::shape::float_type, {32}};
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 32}});
     l0 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
@@ -54,7 +55,6 @@ TEST_CASE(batchnorm_test)
     auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
     EXPECT(p == prog);
 }
-
 TEST_CASE(biasadd_test)
@@ -107,10 +107,10 @@ TEST_CASE(conv_test)
     migraphx::op::convolution op;
     op.padding_mode = migraphx::op::padding_mode_t::same;
-    op.stride = {1,1};
-    op.dilation = {1,1};
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0,3,1,2}}, l0);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{3,2,0,1}}, l1);
+    op.stride = {1, 1};
+    op.dilation = {1, 1};
+    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
+    auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
     p.add_instruction(op, l2, l3);
     auto prog = migraphx::parse_tf("conv_test.pb", true);
@@ -164,8 +164,8 @@ TEST_CASE(reshape_test)
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
     migraphx::shape s0{migraphx::shape::int32_type, {4}};
     // in tf, the second arg is a literal that contains new dimensions
-    p.add_literal(migraphx::literal{s0, {1,1,1,16}});
-    p.add_instruction(migraphx::op::reshape{{1,1,1,16}}, l0);
+    p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
+    p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
     auto prog = migraphx::parse_tf("reshape_test.pb", false);
     EXPECT(p == prog);
@@ -187,7 +187,7 @@ TEST_CASE(softmax_test)
 TEST_CASE(squeeze_test)
 {
     migraphx::program p;
-    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1,2,3,1}});
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
     p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
     auto prog = migraphx::parse_tf("squeeze_test.pb", false);