Commit fa983162 authored by Khalique's avatar Khalique
Browse files

formatting

parent ee39cf0c
......@@ -56,9 +56,10 @@ struct tf_parser
std::vector<T> new_axes;
if(is_nhwc)
{
std::transform(axes.begin(), axes.end(), std::back_inserter(new_axes), [&](size_t axis) {
return parse_axis(axis);
});
std::transform(axes.begin(),
axes.end(),
std::back_inserter(new_axes),
[&](size_t axis) { return parse_axis(axis); });
}
return new_axes;
}
......@@ -334,14 +335,13 @@ struct tf_parser
return prog.add_instruction(op, {args[0], weights});
}
instruction_ref parse_mean(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2,3};
std::vector<int32_t> hw_axes{2, 3};
if(axes == hw_axes and keep_dims)
{
op::pooling op{"average"};
......@@ -358,23 +358,23 @@ struct tf_parser
{
size_t ndims = args.front()->get_shape().lens().size();
// in tf, the paddings are arranged as a 2d shape (ndims, 2),
// the last dim contains the left padding and right padding respectively
// in tf, the paddings are arranged as a 2d shape (ndims, 2),
// the last dim contains the left padding and right padding respectively
std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
for(size_t i = 0; i < 2*ndims; i+= 2)
for(size_t i = 0; i < 2 * ndims; i += 2)
{
pad_per_dim[i/2].first = tf_padding[i];
pad_per_dim[i/2].second = tf_padding[i+1];
pad_per_dim[i / 2].first = tf_padding[i];
pad_per_dim[i / 2].second = tf_padding[i + 1];
}
reorder_data(pad_per_dim);
op::pad op;
std::vector<int64_t> pads(ndims*2);
for (size_t i = 0; i < ndims; i++)
std::vector<int64_t> pads(ndims * 2);
for(size_t i = 0; i < ndims; i++)
{
pads[i] = pad_per_dim[i].first;
pads[i+ndims] = pad_per_dim[i].second;
pads[i] = pad_per_dim[i].first;
pads[i + ndims] = pad_per_dim[i].second;
}
op.pads = pads;
return prog.add_instruction(op, args.front());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment