Commit 638db0ea authored by Khalique's avatar Khalique
Browse files

Added final tests; adjusted the parser to fix edge cases.

parent 1bd9d6a0
......@@ -169,6 +169,13 @@ struct tf_parser
epsilon = attributes.at("epsilon").f();
}
auto l0 = args[0];
if(l0->name() == "@param")
{
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
};
args[0] = l0;
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
}
......@@ -283,7 +290,9 @@ struct tf_parser
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
}
auto l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
auto l1 = args[1];
if (is_nhwc)
l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
return prog.add_instruction(op, {l0, l1});
}
......@@ -395,7 +404,21 @@ struct tf_parser
auto l0 = args[0];
if(is_nhwc)
{
l0 = prog.add_instruction(op::transpose{{0, 2, 3, 1}}, args[0]);
if(l0->name() != "@param")
// squeeze dims are represented for nhwc,
// but intermediate shapes are in nchw
l0 = prog.add_instruction(op::transpose{{0, 2, 3, 1}}, args[0]);
}
auto l0_dims = l0->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{
for (size_t i = 0; i < l0_dims.size(); i++)
{
if (l0_dims.at(i) == 1)
{
op.axes.push_back(i);
}
}
}
return prog.add_instruction(op, l0);
}
......@@ -565,7 +588,7 @@ struct tf_parser
{
dims = {1};
}
size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data
{
const std::string& s = t.tensor_content();
......@@ -635,26 +658,30 @@ struct tf_parser
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, t.float_val().begin(), t.float_val().end()};
return literal{{shape::float_type, dims}, get_data_vals(t.float_val(), shape_size)};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8:
return literal{{shape::int32_type, dims}, t.int_val().begin(), t.int_val().end()};
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, t.int_val().begin(), t.int_val().end()};
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, t.int_val().begin(), t.int_val().end()};
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, t.int_val().begin(), t.int_val().end()};
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, t.int64_val().begin(), t.int64_val().end()};
return literal{{shape::int64_type, dims}, get_data_vals(t.int64_val(), shape_size)};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL:
return literal{{shape::int32_type, dims}, t.bool_val().begin(), t.bool_val().end()};
return literal{{shape::int32_type, dims}, get_data_vals(t.bool_val(), shape_size)};
case tensorflow::DataType::DT_HALF:
return literal{{shape::half_type, dims}, t.half_val().begin(), t.half_val().end()};
{
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
return literal{{shape::half_type, dims}, data_uint16};
}
case tensorflow::DataType::DT_DOUBLE:
return literal{
{shape::double_type, dims}, t.double_val().begin(), t.double_val().end()};
{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
......@@ -698,6 +725,20 @@ struct tf_parser
MIGRAPHX_THROW("Invalid tensor type");
}
// Flatten a repeated scalar field from a tensorflow::TensorProto into a
// vector of exactly shape_size elements.
//
// TF protos may store a single value that stands for the whole tensor, so a
// one-element field is broadcast (replicated) shape_size times; otherwise the
// stored values are copied as-is.
//
// @param data        repeated field (float_val, int_val, ...) from the proto
// @param shape_size  element count implied by the tensor's dims
// @return vector of shape_size values
//         (assumes data.size() is either 1 or shape_size — TODO confirm
//          against all parse_tensor call sites)
template <class T>
static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
                                    const std::size_t& shape_size)
{
    std::vector<T> data_vals(shape_size);
    if(data.size() == 1)
    {
        // single stored value: broadcast it across the whole tensor
        std::fill(data_vals.begin(), data_vals.end(), data[0]);
    }
    else
    {
        // Copy into the pre-sized buffer. The previous std::back_inserter
        // appended AFTER the shape_size value-initialized elements, producing
        // a vector of shape_size + data.size() entries (zeros then data).
        std::copy(data.begin(), data.end(), data_vals.begin());
    }
    return data_vals;
}
static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
{
std::vector<size_t> dims;
......
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:

conv1Conv2D01*
dilations
*
T0*
data_formatNHWC*
strides
*
use_cudnn_on_gpu(*
paddingSAME"
\ No newline at end of file
2
0 Placeholder*
dtype0*
shape
:

softmaxSoftmax0*
T0"
\ No newline at end of file
......@@ -33,6 +33,30 @@ TEST_CASE(add_bcast_test)
EXPECT(p == prog);
}
TEST_CASE(batchnorm_test)
{
    // Expected program: NHWC input is transposed to NCHW before
    // batch_norm_inference; input "1" is a constant of ones in the protobuf,
    // while "2", "3", "4" remain graph parameters.
    // NOTE: instruction insertion order must match the parser's output
    // exactly, since EXPECT compares whole programs.
    const float eps = 1.001e-5f;
    const float mom = 0.9f;
    migraphx::program p;
    migraphx::op::batch_norm_inference bn_op{
        eps, mom, migraphx::op::batch_norm_inference::spatial};
    migraphx::shape chan_shape{migraphx::shape::float_type, {32}};
    auto input = p.add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 32}});
    input = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, input);
    std::vector<float> ones(32, 1.0f);
    auto arg2     = p.add_parameter("2", chan_shape);
    auto arg3     = p.add_parameter("3", chan_shape);
    auto arg4     = p.add_parameter("4", chan_shape);
    auto ones_lit = p.add_literal(migraphx::literal{chan_shape, ones});
    p.add_instruction(bn_op, input, ones_lit, arg2, arg3, arg4);
    auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
    EXPECT(p == prog);
}
TEST_CASE(biasadd_test)
{
migraphx::program p;
......@@ -48,6 +72,51 @@ TEST_CASE(biasadd_test)
EXPECT(p == prog);
}
TEST_CASE(concat_test)
{
    migraphx::program p;
    auto in0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
    auto in1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
    const int axis = 1;
    // TF supplies the concat axis as a third int32 input, so the parser emits
    // it as a literal before the concat instruction itself.
    p.add_literal(axis);
    p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, in0, in1);
    auto prog = migraphx::parse_tf("concat_test.pb", false);
    EXPECT(p == prog);
}
TEST_CASE(const_test)
{
    // A lone TF Const node parses to a single scalar literal.
    migraphx::program p;
    p.add_literal(1.0f);
    auto prog = migraphx::parse_tf("constant_test.pb", false);
    EXPECT(p == prog);
}
TEST_CASE(conv_test)
{
    migraphx::program p;
    auto input   = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 3}});
    auto weights = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}});
    migraphx::op::convolution conv_op;
    conv_op.padding_mode = migraphx::op::padding_mode_t::same;
    conv_op.stride       = {1, 1};
    conv_op.dilation     = {1, 1};
    // Input is permuted {0,3,1,2} (NHWC -> NCHW); weights are permuted
    // {3,2,0,1} (presumably TF HWIO -> OIHW).
    auto t_input   = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, input);
    auto t_weights = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, weights);
    p.add_instruction(conv_op, t_input, t_weights);
    auto prog = migraphx::parse_tf("conv_test.pb", true);
    EXPECT(p == prog);
}
TEST_CASE(identity_test)
{
migraphx::program p;
......@@ -89,4 +158,39 @@ TEST_CASE(relu_test)
EXPECT(p == prog);
}
TEST_CASE(reshape_test)
{
    migraphx::program p;
    auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
    migraphx::shape dims_shape{migraphx::shape::int32_type, {4}};
    // TF passes the target dimensions as a second int32 input, which the
    // parser materializes as a literal ahead of the reshape.
    p.add_literal(migraphx::literal{dims_shape, {1, 1, 1, 16}});
    p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, input);
    auto prog = migraphx::parse_tf("reshape_test.pb", false);
    EXPECT(p == prog);
}
TEST_CASE(softmax_test)
{
    migraphx::program p;
    auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
    auto lens  = input->get_shape().lens();
    // The parser expands the 2D input to 4D, applies softmax, then reshapes
    // back to the original 2D extents.
    auto expanded = p.add_instruction(
        migraphx::op::reshape{{long(lens[0]), long(lens[1]), 1, 1}}, input);
    auto sm = p.add_instruction(migraphx::op::softmax{}, expanded);
    p.add_instruction(migraphx::op::reshape{{long(lens[0]), long(lens[1])}}, sm);
    auto prog = migraphx::parse_tf("softmax_test.pb", false);
    EXPECT(p == prog);
}
TEST_CASE(squeeze_test)
{
    migraphx::program p;
    auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
    // Both unit dimensions (axes 0 and 3) are squeezed away.
    p.add_instruction(migraphx::op::squeeze{{0, 3}}, input);
    auto prog = migraphx::parse_tf("squeeze_test.pb", false);
    EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment