Commit 38bde548 authored by Khalique's avatar Khalique
Browse files

formatting

parent 638db0ea
......@@ -291,7 +291,7 @@ struct tf_parser
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
}
auto l1 = args[1];
if (is_nhwc)
if(is_nhwc)
l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
return prog.add_instruction(op, {l0, l1});
}
......@@ -405,16 +405,16 @@ struct tf_parser
if(is_nhwc)
{
if(l0->name() != "@param")
// squeeze dims are represented for nhwc,
// but intermediate shapes are in nchw
// squeeze dims are represented for nhwc,
// but intermediate shapes are in nchw
l0 = prog.add_instruction(op::transpose{{0, 2, 3, 1}}, args[0]);
}
auto l0_dims = l0->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{
for (size_t i = 0; i < l0_dims.size(); i++)
for(size_t i = 0; i < l0_dims.size(); i++)
{
if (l0_dims.at(i) == 1)
if(l0_dims.at(i) == 1)
{
op.axes.push_back(i);
}
......@@ -674,14 +674,13 @@ struct tf_parser
case tensorflow::DataType::DT_BOOL:
return literal{{shape::int32_type, dims}, get_data_vals(t.bool_val(), shape_size)};
case tensorflow::DataType::DT_HALF:
{
{
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
return literal{{shape::half_type, dims}, data_uint16};
}
}
case tensorflow::DataType::DT_DOUBLE:
return literal{
{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
......@@ -726,7 +725,8 @@ struct tf_parser
}
template <class T>
static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data, const std::size_t& shape_size)
static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
const std::size_t& shape_size)
{
std::vector<T> data_vals(shape_size);
// check if shape has enough data values given existing fields
......
......@@ -35,17 +35,18 @@ TEST_CASE(add_bcast_test)
TEST_CASE(batchnorm_test)
{
float epsilon = 1.001e-5f;
float momentum = 0.9f;
float epsilon = 1.001e-5f;
float momentum = 0.9f;
migraphx::program p;
migraphx::op::batch_norm_inference op{epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
migraphx::op::batch_norm_inference op{
epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
migraphx::shape s0{migraphx::shape::float_type, {32}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 32}});
l0 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
l0 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
std::vector<float> const_vals(32);
std::fill(const_vals.begin(), const_vals.end(), 1.0f);
auto l2 = p.add_parameter("2", s0);
auto l3 = p.add_parameter("3", s0);
auto l4 = p.add_parameter("4", s0);
......@@ -54,7 +55,6 @@ TEST_CASE(batchnorm_test)
auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(biasadd_test)
......@@ -75,7 +75,7 @@ TEST_CASE(biasadd_test)
TEST_CASE(concat_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
......@@ -101,16 +101,16 @@ TEST_CASE(const_test)
TEST_CASE(conv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}});
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.stride = {1,1};
op.dilation = {1,1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0,3,1,2}}, l0);
auto l3 = p.add_instruction(migraphx::op::transpose{{3,2,0,1}}, l1);
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
p.add_instruction(op, l2, l3);
auto prog = migraphx::parse_tf("conv_test.pb", true);
......@@ -164,8 +164,8 @@ TEST_CASE(reshape_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}};
// in tf, the second arg is a literal that contains new dimensions
p.add_literal(migraphx::literal{s0, {1,1,1,16}});
p.add_instruction(migraphx::op::reshape{{1,1,1,16}}, l0);
p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
auto prog = migraphx::parse_tf("reshape_test.pb", false);
EXPECT(p == prog);
......@@ -174,10 +174,10 @@ TEST_CASE(reshape_test)
TEST_CASE(softmax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
auto dims = l0->get_shape().lens();
auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
auto prog = migraphx::parse_tf("softmax_test.pb", false);
......@@ -187,7 +187,7 @@ TEST_CASE(softmax_test)
TEST_CASE(squeeze_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1,2,3,1}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
auto prog = migraphx::parse_tf("squeeze_test.pb", false);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment