Commit db4df30b authored by Khalique's avatar Khalique
Browse files

formatting

parent 68d5b22b
...@@ -37,7 +37,8 @@ struct tf_parser ...@@ -37,7 +37,8 @@ struct tf_parser
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s, const size_t& num_dims) const std::vector<size_t>
parse_axes(const attribute_map& attributes, const std::string& s, const size_t& num_dims) const
{ {
auto attrs = attributes.at(s).list().i(); auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes; std::vector<size_t> axes;
...@@ -252,7 +253,8 @@ struct tf_parser ...@@ -252,7 +253,8 @@ struct tf_parser
{ {
// get index for axis within args // get index for axis within args
size_t axis_idx = attributes.at("N").i(); size_t axis_idx = attributes.at("N").i();
size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>(), args[0]->get_shape().lens().size()); size_t axis =
parse_axis(args[axis_idx]->eval().at<int64_t>(), args[0]->get_shape().lens().size());
op::concat op{axis}; op::concat op{axis};
// return only first N arguments (assuming last index is the axis value) // return only first N arguments (assuming last index is the axis value)
return prog.add_instruction( return prog.add_instruction(
...@@ -471,14 +473,15 @@ struct tf_parser ...@@ -471,14 +473,15 @@ struct tf_parser
return prog.add_instruction(op, {l0, new_weights}); return prog.add_instruction(op, {l0, new_weights});
} }
instruction_ref parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args) instruction_ref
parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
std::vector<size_t> input_dims = args[0]->get_shape().lens(); std::vector<size_t> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end()); std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
size_t num_dims = input_dims.size(); size_t num_dims = input_dims.size();
int32_t dim = parse_axis(args[1]->eval().at<int32_t>(), num_dims); int32_t dim = parse_axis(args[1]->eval().at<int32_t>(), num_dims);
if (dim < 0) if(dim < 0)
{ {
new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1); new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
} }
...@@ -696,7 +699,7 @@ struct tf_parser ...@@ -696,7 +699,7 @@ struct tf_parser
{ {
op::squeeze op; op::squeeze op;
auto input_dims = args[0]->get_shape().lens(); auto input_dims = args[0]->get_shape().lens();
auto axes = parse_axes(attributes, "squeeze_dims", input_dims.size()); auto axes = parse_axes(attributes, "squeeze_dims", input_dims.size());
copy(axes, std::back_inserter(op.axes)); copy(axes, std::back_inserter(op.axes));
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1 if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
......
...@@ -149,12 +149,12 @@ TEST_CASE(depthwiseconv_test) ...@@ -149,12 +149,12 @@ TEST_CASE(depthwiseconv_test)
TEST_CASE(expanddims_test) TEST_CASE(expanddims_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2,3,4}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(-1); p.add_literal(-1);
p.add_literal(0); p.add_literal(0);
p.add_instruction(migraphx::op::reshape{{2,3,4,1}}, l0); p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
p.add_instruction(migraphx::op::reshape{{1,2,3,4}}, l0); p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
auto prog = migraphx::parse_tf("expanddims_test.pb", true); auto prog = migraphx::parse_tf("expanddims_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment