"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "fdfff50d362c45f88f916a60756933405501ed29"
Commit 2c8acfb8 authored by Khalique's avatar Khalique
Browse files

Merge branch 'tf_pb' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into tf_pb_py

parents 75aa4baf fa983162
...@@ -50,6 +50,23 @@ struct tf_parser ...@@ -50,6 +50,23 @@ struct tf_parser
return axes; return axes;
} }
template <class T>
std::vector<T> parse_axes(std::vector<T> axes) const
{
std::vector<T> new_axes;
if(is_nhwc)
{
std::transform(axes.begin(),
axes.end(),
std::back_inserter(new_axes),
[&](size_t axis) { return parse_axis(axis); });
}
return new_axes;
}
// tf stores certain attributes such as strides, dilations, as a 4D input.
// The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
// This helper function reorders the data to store for the respective operator member variables.
template <class T> template <class T>
void reorder_data(std::vector<T>& prev_data) const void reorder_data(std::vector<T>& prev_data) const
{ {
...@@ -101,6 +118,8 @@ struct tf_parser ...@@ -101,6 +118,8 @@ struct tf_parser
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape); add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze); add_mem_op("Squeeze", &tf_parser::parse_squeeze);
...@@ -316,6 +335,51 @@ struct tf_parser ...@@ -316,6 +335,51 @@ struct tf_parser
return prog.add_instruction(op, {args[0], weights}); return prog.add_instruction(op, {args[0], weights});
} }
instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3};
if(axes == hw_axes and keep_dims)
{
op::pooling op{"average"};
std::vector<size_t> input_dims{args[0]->get_shape().lens()};
op.lengths[0] = input_dims[2];
op.lengths[1] = input_dims[3];
return prog.add_instruction(op, args.front());
}
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
size_t ndims = args.front()->get_shape().lens().size();
// in tf, the paddings are arranged as a 2d shape (ndims, 2),
// the last dim contains the left padding and right padding respectively
std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
for(size_t i = 0; i < 2 * ndims; i += 2)
{
pad_per_dim[i / 2].first = tf_padding[i];
pad_per_dim[i / 2].second = tf_padding[i + 1];
}
reorder_data(pad_per_dim);
op::pad op;
std::vector<int64_t> pads(ndims * 2);
for(size_t i = 0; i < ndims; i++)
{
pads[i] = pad_per_dim[i].first;
pads[i + ndims] = pad_per_dim[i].second;
}
op.pads = pads;
return prog.add_instruction(op, args.front());
}
instruction_ref parse_pooling(const std::string& name, instruction_ref parse_pooling(const std::string& name,
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment