Commit 68d5b22b authored by Khalique

add expanddims plus tests

parent 15eb1987
@@ -37,7 +37,7 @@ struct tf_parser
     std::unordered_map<std::string, op_func> ops;

-    std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
+    std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s, const size_t& num_dims) const
     {
         auto attrs = attributes.at(s).list().i();
         std::vector<size_t> axes;
@@ -45,14 +45,14 @@ struct tf_parser
         if(is_nhwc)
         {
             std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
-                return parse_axis(axis);
+                return parse_axis(axis, num_dims);
             });
         }
         return axes;
     }

     template <class T>
-    std::vector<T> parse_axes(std::vector<T> axes) const
+    std::vector<T> parse_axes(std::vector<T> axes, const size_t& num_dims) const
     {
         if(is_nhwc)
         {
@@ -60,7 +60,7 @@ struct tf_parser
             std::transform(axes.begin(),
                            axes.end(),
                            std::back_inserter(new_axes),
-                           [&](size_t axis) { return parse_axis(axis); });
+                           [&](size_t axis) { return parse_axis(axis, num_dims); });
             return new_axes;
         }
         return axes;
@@ -75,17 +75,17 @@ struct tf_parser
         std::vector<T> new_data(prev_data.size());
         for(size_t i = 0; i < new_data.size(); i++)
         {
-            auto new_idx = parse_axis(i);
+            auto new_idx = parse_axis(i, new_data.size());
             new_data.at(new_idx) = prev_data.at(i);
         }
         prev_data = new_data;
     }

     template <class T>
-    T parse_axis(const T& dim) const
+    T parse_axis(const T& dim, const size_t& num_dims) const
     {
         T new_dim = dim;
-        if(is_nhwc)
+        if(is_nhwc and num_dims >= 4)
         {
             switch(dim)
             {
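Note: the body of the switch is outside this hunk, but parse_axis is the parser's NHWC-to-NCHW axis remap, and the new num_dims parameter makes it a no-op for tensors with fewer than four dimensions, so callers no longer need their own rank checks. A minimal sketch of the assumed behavior follows; the helper name and the exact case mapping are assumptions (the switch body is not shown here), and is_nhwc is taken as true.

```cpp
#include <cstddef>

// Hypothetical stand-in for parse_axis with is_nhwc == true; the case mapping
// below is the usual NHWC -> NCHW permutation and is assumed, not copied.
std::size_t nhwc_to_nchw_axis(std::size_t axis, std::size_t num_dims)
{
    if(num_dims < 4) // new guard: sub-4-D tensors keep their axes unchanged
        return axis;
    switch(axis)
    {
    case 1: return 2;     // H
    case 2: return 3;     // W
    case 3: return 1;     // C
    default: return axis; // N (axis 0) stays put
    }
}
```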
@@ -121,6 +121,7 @@ struct tf_parser
         add_mem_op("Const", &tf_parser::parse_constant);
         add_mem_op("Conv2D", &tf_parser::parse_conv);
         add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
+        add_mem_op("ExpandDims", &tf_parser::parse_expanddims);
         add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
         add_mem_op("MatMul", &tf_parser::parse_matmul);
         add_mem_op("MaxPool", &tf_parser::parse_pooling);
@@ -251,7 +252,7 @@ struct tf_parser
     {
         // get index for axis within args
         size_t axis_idx = attributes.at("N").i();
-        size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>());
+        size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>(), args[0]->get_shape().lens().size());
         op::concat op{axis};
         // return only first N arguments (assuming last index is the axis value)
         return prog.add_instruction(
@@ -470,6 +471,24 @@ struct tf_parser
         return prog.add_instruction(op, {l0, new_weights});
     }

+    instruction_ref parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    {
+        std::vector<size_t> input_dims = args[0]->get_shape().lens();
+        std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
+        size_t num_dims = input_dims.size();
+        int32_t dim = parse_axis(args[1]->eval().at<int32_t>(), num_dims);
+        if(dim < 0)
+        {
+            new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
+        }
+        else
+        {
+            new_dims.insert(new_dims.begin() + dim, 1);
+        }
+        return prog.add_instruction(op::reshape{new_dims}, args[0]);
+    }
+
     instruction_ref
     parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
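Note on the negative-axis branch in parse_expanddims above: TensorFlow's ExpandDims takes an axis in [-rank-1, rank], and a negative value counts from the end of the output shape, so num_dims + dim + 1 is the insertion offset (for a rank-3 input, dim = -1 gives offset 3, i.e. a trailing 1). Because a 3-D input falls below the new num_dims >= 4 threshold, parse_axis leaves the axis untouched. A self-contained sketch of the same arithmetic is below; expanddims_shape is a hypothetical helper, not part of the parser.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the insert arithmetic used by parse_expanddims above.
std::vector<int64_t> expanddims_shape(std::vector<int64_t> dims, int32_t dim)
{
    auto num_dims = static_cast<int64_t>(dims.size());
    if(dim < 0)
        dims.insert(dims.begin() + (num_dims + dim + 1), 1); // count from the back
    else
        dims.insert(dims.begin() + dim, 1);
    return dims;
}

int main()
{
    // Matches the shapes expected by expanddims_test below.
    assert((expanddims_shape({2, 3, 4}, -1) == std::vector<int64_t>{2, 3, 4, 1}));
    assert((expanddims_shape({2, 3, 4}, 0) == std::vector<int64_t>{1, 2, 3, 4}));
    return 0;
}
```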
@@ -499,11 +518,12 @@ struct tf_parser
     instruction_ref
     parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
-        auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
         bool keep_dims = attributes.at("keep_dims").b();
         std::vector<int32_t> hw_axes{2, 3};
         // check if conditions for GlobalAvgPool are met
         auto lens = args[0]->get_shape().lens();
+        auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size());
         if(axes == hw_axes and lens.size() == 4)
         {
             op::pooling op{"average"};
@@ -534,8 +554,7 @@ struct tf_parser
                                " must be smaller than input size " + to_string(input_size));
         }
         // check if input arg needs axis to be converted to NCHW
-        if(input_size >= 4)
-            axis = parse_axis(axis);
+        axis = parse_axis(axis, input_size);
         std::transform(
             args.begin(),
@@ -676,14 +695,15 @@ struct tf_parser
                   std::vector<instruction_ref> args)
     {
         op::squeeze op;
-        auto axes = parse_axes(attributes, "squeeze_dims");
+        auto input_dims = args[0]->get_shape().lens();
+        auto axes = parse_axes(attributes, "squeeze_dims", input_dims.size());
         copy(axes, std::back_inserter(op.axes));
-        auto args0_dims = args[0]->get_shape().lens();
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
         {
-            for(size_t i = 0; i < args0_dims.size(); i++)
+            for(size_t i = 0; i < input_dims.size(); i++)
             {
-                if(args0_dims.at(i) == 1)
+                if(input_dims.at(i) == 1)
                 {
                     op.axes.push_back(i);
                 }
@@ -723,10 +743,7 @@ struct tf_parser
             if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                 squeeze_axes.push_back(i);
         }
-        if(num_axes >= 4)
-        {
-            squeeze_axes = parse_axes(squeeze_axes);
-        }
+        squeeze_axes = parse_axes(squeeze_axes, num_axes);
         auto l0 = prog.add_instruction(op, args[0]);
         return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
...
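Note on the parse_mean hunk above: computing lens before the parse_axes call lets the input rank be forwarded, and the comparison against hw_axes{2, 3} works because a TF reduce_mean over the NHWC spatial axes {1, 2} lands on NCHW axes {2, 3} for a 4-D input (assuming the usual permutation). A small sketch reusing the hypothetical nhwc_to_nchw_axis helper from the earlier note:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

std::size_t nhwc_to_nchw_axis(std::size_t axis, std::size_t num_dims); // from the sketch above

// Remap a TF axis list for an input of the given rank; remap_axes({1, 2}, 4)
// yields {2, 3}, which is exactly what the GlobalAvgPool check looks for.
std::vector<int32_t> remap_axes(const std::vector<int32_t>& axes, std::size_t num_dims)
{
    std::vector<int32_t> out;
    for(auto a : axes)
        out.push_back(static_cast<int32_t>(nhwc_to_nchw_axis(a, num_dims)));
    return out;
}
```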
@@ -146,6 +146,20 @@ TEST_CASE(depthwiseconv_test)
     EXPECT(p == prog);
 }

+TEST_CASE(expanddims_test)
+{
+    migraphx::program p;
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
+    p.add_literal(-1);
+    p.add_literal(0);
+    p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
+    p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
+    auto prog = migraphx::parse_tf("expanddims_test.pb", true);
+    EXPECT(p == prog);
+}
+
 TEST_CASE(identity_test)
 {
     migraphx::program p;
...