Commit 6336ed52 authored by Khalique's avatar Khalique
Browse files

change transpose to false, adjusted tests

parent f8ec4fa7
...@@ -80,7 +80,7 @@ struct tf_parser ...@@ -80,7 +80,7 @@ struct tf_parser
} }
std::vector<size_t> std::vector<size_t>
parse_axes(const attribute_map& attributes, const std::string& s, const size_t& num_dims) const parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
{ {
auto attrs = attributes.at(s).list().i(); auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes; std::vector<size_t> axes;
...@@ -95,7 +95,7 @@ struct tf_parser ...@@ -95,7 +95,7 @@ struct tf_parser
} }
template <class T> template <class T>
std::vector<T> parse_axes(std::vector<T> axes, const size_t& num_dims) const std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
{ {
if(is_nhwc) if(is_nhwc)
{ {
...@@ -125,7 +125,7 @@ struct tf_parser ...@@ -125,7 +125,7 @@ struct tf_parser
} }
template <class T> template <class T>
T parse_axis(const T& dim, const size_t& num_dims) const T parse_axis(const T& dim, const size_t num_dims) const
{ {
T new_dim = dim; T new_dim = dim;
if(is_nhwc and num_dims >= 4) if(is_nhwc and num_dims >= 4)
...@@ -166,7 +166,7 @@ struct tf_parser ...@@ -166,7 +166,7 @@ struct tf_parser
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv); add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
add_mem_op("ExpandDims", &tf_parser::parse_expanddims); add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul, false); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
...@@ -498,7 +498,7 @@ struct tf_parser ...@@ -498,7 +498,7 @@ struct tf_parser
std::vector<size_t> input_dims = args[0]->get_shape().lens(); std::vector<size_t> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end()); std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
size_t num_dims = input_dims.size(); size_t num_dims = input_dims.size();
int32_t dim = parse_axis(args[1]->eval().at<int32_t>(), num_dims); int32_t dim = args[1]->eval().at<int32_t>();
if(dim < 0) if(dim < 0)
{ {
......
...@@ -164,8 +164,9 @@ TEST_CASE(expanddims_test) ...@@ -164,8 +164,9 @@ TEST_CASE(expanddims_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(0);
p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0); p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
auto prog = optimize_tf("expanddims_test.pb", true); auto prog = optimize_tf("expanddims_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -176,8 +177,9 @@ TEST_CASE(expanddims_test_neg_dims) ...@@ -176,8 +177,9 @@ TEST_CASE(expanddims_test_neg_dims)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(-1);
p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0); p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
auto prog = optimize_tf("expanddims_neg_test.pb", true); auto prog = optimize_tf("expanddims_neg_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment