"vscode:/vscode.git/clone" did not exist on "de786192c9c0cdd5ccdfafe7e952ff3165a69792"
Commit ee39cf0c authored by Khalique's avatar Khalique
Browse files

added pad and mean op

parent ff009f50
...@@ -50,6 +50,19 @@ struct tf_parser ...@@ -50,6 +50,19 @@ struct tf_parser
return axes; return axes;
} }
template <class T>
std::vector<T> parse_axes(std::vector<T> axes) const
{
std::vector<T> new_axes;
if(is_nhwc)
{
std::transform(axes.begin(), axes.end(), std::back_inserter(new_axes), [&](size_t axis) {
return parse_axis(axis);
});
}
return new_axes;
}
// tf stores certain attributes such as strides, dilations, as a 4D input. // tf stores certain attributes such as strides, dilations, as a 4D input.
// The first and last dims are equal to 1, and the relevant data is in dims 2 and 3. // The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
// This helper function reorders the data to store for the respective operator member variables. // This helper function reorders the data to store for the respective operator member variables.
...@@ -104,6 +117,8 @@ struct tf_parser ...@@ -104,6 +117,8 @@ struct tf_parser
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape); add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze); add_mem_op("Squeeze", &tf_parser::parse_squeeze);
...@@ -319,6 +334,52 @@ struct tf_parser ...@@ -319,6 +334,52 @@ struct tf_parser
return prog.add_instruction(op, {args[0], weights}); return prog.add_instruction(op, {args[0], weights});
} }
instruction_ref parse_mean(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2,3};
if(axes == hw_axes and keep_dims)
{
op::pooling op{"average"};
std::vector<size_t> input_dims{args[0]->get_shape().lens()};
op.lengths[0] = input_dims[2];
op.lengths[1] = input_dims[3];
return prog.add_instruction(op, args.front());
}
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
size_t ndims = args.front()->get_shape().lens().size();
// in tf, the paddings are arranged as a 2d shape (ndims, 2),
// the last dim contains the left padding and right padding respectively
std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
for(size_t i = 0; i < 2*ndims; i+= 2)
{
pad_per_dim[i/2].first = tf_padding[i];
pad_per_dim[i/2].second = tf_padding[i+1];
}
reorder_data(pad_per_dim);
op::pad op;
std::vector<int64_t> pads(ndims*2);
for (size_t i = 0; i < ndims; i++)
{
pads[i] = pad_per_dim[i].first;
pads[i+ndims] = pad_per_dim[i].second;
}
op.pads = pads;
return prog.add_instruction(op, args.front());
}
instruction_ref parse_pooling(const std::string& name, instruction_ref parse_pooling(const std::string& name,
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment