Unverified Commit d237415f authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Refactor tf parser (#726)



* refactor files

* formatting

* fix add_bcast test

* fix some tidy errors

* add transpose field and fix more tidy

* formatting

* add pad test and fix more tidy

* formatting

* fix conv parser

* fix depthwiseconv

* remove unused functions

* remove includes and functions
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 3d24a21c
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_pad : op_parser<parse_pad>
{
    // Inputs are transposed NHWC -> NCHW before this parser runs.
    bool transpose() const { return true; }
    std::vector<op_desc> operators() const { return {{"Pad"}}; }

    /// Convert a TF Pad node into a migraphx "pad" instruction.
    /// args[0] is the tensor to pad; args[1] is a constant int32 tensor of
    /// shape (ndims, 2) holding the {before, after} padding per dimension.
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& parser,
                          const tf_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        size_t ndims = args.front()->get_shape().lens().size();

        // in tf, the paddings are arranged as a 2d shape (ndims, 2),
        // the last dim contains the left padding and right padding respectively
        auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
        // validate before indexing: a malformed constant would otherwise
        // cause out-of-bounds reads below
        if(tf_padding.size() != 2 * ndims)
        {
            MIGRAPHX_THROW("PARSE_PAD: paddings should have " + std::to_string(2 * ndims) +
                           " values");
        }
        std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
        for(size_t i = 0; i < 2 * ndims; i += 2)
        {
            pad_per_dim[i / 2].first  = tf_padding[i];
            pad_per_dim[i / 2].second = tf_padding[i + 1];
        }
        // reorder padding pairs from NHWC to NCHW when the parser requires it
        parser.reorder_data(pad_per_dim);

        // migraphx pad expects {d0_begin, ..., dN_begin, d0_end, ..., dN_end}
        std::vector<int64_t> pads(ndims * 2);
        for(size_t i = 0; i < ndims; i++)
        {
            pads[i]         = pad_per_dim[i].first;
            pads[i + ndims] = pad_per_dim[i].second;
        }
        return info.add_instruction(make_op("pad", {{"pads", pads}}), args.front());
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_pooling : op_parser<parse_pooling>
{
    // Inputs are transposed NHWC -> NCHW before this parser runs.
    bool transpose() const { return true; }
    std::vector<op_desc> operators() const { return {{"AvgPool"}, {"MaxPool"}}; }

    /// Lower TF AvgPool/MaxPool to a migraphx pooling instruction.
    /// Reads the "strides"/"ksize" attributes (NHWC order, reordered to
    /// NCHW) and handles SAME padding, inserting an explicit pad
    /// instruction when the required padding is asymmetric.
    instruction_ref parse(const op_desc& opd,
                          const tf_parser& parser,
                          tf_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        const bool is_max = starts_with(opd.tf_name, "Max");
        op::pooling op{is_max ? "max" : "average"};

        if(contains(info.attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
            parser.reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            // after NCHW reordering, the spatial strides are at positions 2 and 3
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(info.attributes, "ksize"))
        {
            std::vector<size_t> ksize;
            copy(info.attributes.at("ksize").list().i(), std::back_inserter(ksize));
            parser.reorder_data(ksize);
            if(ksize.size() != 4)
            {
                MIGRAPHX_THROW("ksize should have 4 values");
            }
            op.lengths[0] = ksize[2];
            op.lengths[1] = ksize[3];
        }

        auto l0 = args[0];
        if(contains(info.attributes, "padding"))
        {
            const std::string& pad_mode = info.attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
                calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    // Asymmetric padding: insert an explicit pad op. The pad
                    // value must be the identity of the pooling reduction:
                    // lowest() for max, 0 for average. Padding average
                    // pooling with lowest() would corrupt every window that
                    // touches the border.
                    // NOTE(review): zero padding still counts the padded
                    // elements in the average's denominator.
                    float pad_val = is_max ? std::numeric_limits<float>::lowest() : 0.0f;
                    std::vector<int64_t> padding = {
                        0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0 = info.add_instruction(
                        migraphx::make_op("pad", {{"pads", padding}, {"value", pad_val}}),
                        l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
        }
        return info.add_instruction(op, l0);
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_relu6 : op_parser<parse_relu6>
{
    std::vector<op_desc> operators() const { return {{"Relu6"}}; }

    /// Relu6(x) == clip(x, 0, 6): broadcast the scalar bounds to the input
    /// shape and emit a single clip instruction.
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          const tf_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto out_lens = args.front()->get_shape().lens();
        // scalar literals for the lower/upper clip bounds
        auto lo_lit = info.add_literal(0.0f);
        auto hi_lit = info.add_literal(6.0f);
        // clip requires the bounds to match the input shape
        auto lo = info.add_instruction(
            make_op("multibroadcast", {{"output_lens", out_lens}}), lo_lit);
        auto hi = info.add_instruction(
            make_op("multibroadcast", {{"output_lens", out_lens}}), hi_lit);
        return info.add_instruction(make_op("clip"), args.front(), lo, hi);
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_reshape : op_parser<parse_reshape>
{
    std::vector<op_desc> operators() const { return {{"Reshape"}}; }

    /// args[0]: tensor to reshape; args[1]: constant holding the target
    /// dimensions, which must be evaluable at parse time.
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          const tf_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        if(args.size() != 2)
            MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
        // pull the requested dims out of the constant second input
        std::vector<int64_t> new_dims;
        args[1]->eval().visit([&](auto v) { copy(v, std::back_inserter(new_dims)); });
        // reshape requires a standard (contiguous) input layout
        auto cont = info.make_contiguous(args[0]);
        return info.add_instruction(make_op("reshape", {{"dims", new_dims}}), cont);
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_shape : op_parser<parse_shape>
{
    std::vector<op_desc> operators() const { return {{"Shape"}}; }

    // Shapes are known at compile time in migraphx, so the TF Shape op is
    // lowered to an int32 literal holding the input's dimensions.
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          const tf_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        const auto& lens = args[0]->get_shape().lens();
        std::vector<int32_t> dims;
        dims.reserve(lens.size());
        // narrow each size_t dimension to the int32 output element type
        for(auto d : lens)
            dims.push_back(static_cast<int32_t>(d));
        migraphx::shape out_shape{migraphx::shape::int32_type, {lens.size()}};
        return info.add_literal(migraphx::literal{out_shape, dims});
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_slice : op_parser<parse_slice>
{
    std::vector<op_desc> operators() const { return {{"Slice"}}; }

    /// TF Slice takes (input, begin, size); migraphx slice takes
    /// starts/ends/axes, so each size is converted into an end offset.
    /// A size of -1 means "to the end of that dimension".
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          const tf_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // both begin and size must be constant-foldable
        auto begins = args[1]->eval().get<int32_t>().to_vector();
        auto sizes  = args[2]->eval().get<int32_t>().to_vector();

        auto input_lens = args[0]->get_shape().lens();
        size_t ndims    = input_lens.size();

        std::vector<int64_t> op_starts(begins.begin(), begins.end());
        std::vector<int64_t> op_ends(ndims);
        // slice along every axis of the input
        std::vector<int64_t> slice_axes(ndims);
        std::iota(slice_axes.begin(), slice_axes.end(), 0);

        for(size_t dim = 0; dim < ndims; ++dim)
        {
            op_ends[dim] = (sizes[dim] == -1) ? static_cast<int64_t>(input_lens[dim])
                                              : op_starts[dim] + sizes[dim];
        }
        return info.add_instruction(
            make_op("slice", {{"starts", op_starts}, {"ends", op_ends}, {"axes", slice_axes}}),
            info.make_contiguous(args[0]));
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_softmax : op_parser<parse_softmax>
{
    std::vector<op_desc> operators() const { return {{"Softmax"}}; }

    /// Lower TF Softmax; TF defaults to the last axis when the "axis"
    /// attribute is absent.
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          tf_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        auto rank = args[0]->get_shape().lens().size();
        int axis  = contains(info.attributes, "axis")
                        ? static_cast<int>(info.attributes.at("axis").i())
                        : -1;
        // normalize (possibly negative) axis and bounds-check it
        axis = tune_axis(rank, axis, "tf_parse_softmax");
        return info.add_instruction(make_op("softmax", {{"axis", axis}}),
                                    info.make_contiguous(args[0]));
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_split : op_parser<parse_split>
{
    std::vector<op_desc> operators() const { return {{"Split"}, {"SplitV"}}; }

    /// Lower TF Split/SplitV to a sequence of slice instructions.
    /// Split:  args = {axis, input}; output sizes come from the
    ///         "num_split" attribute (input split evenly).
    /// SplitV: args = {input, size_splits, axis}; per-output sizes come
    ///         from the constant size_splits tensor.
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const tf_parser& /*parser*/,
                                       tf_parser::node_info info,
                                       std::vector<instruction_ref> args) const
    {
        // SplitV passes the per-output sizes as an extra input
        bool vector_as_input = args.size() == 3;
        auto axis_arg        = vector_as_input ? args[2] : args[0];
        auto input_arg       = vector_as_input ? args[0] : args[1];

        int num_outputs = 1;
        if(contains(info.attributes, "num_split"))
            num_outputs = info.attributes.at("num_split").i();

        std::vector<int> splits;
        if(vector_as_input)
        {
            splits      = args[1]->eval().get<int32_t>().to_vector();
            num_outputs = splits.size();
        }
        assert(num_outputs > 0);
        // a single output degenerates to an identity
        if(num_outputs == 1)
            return {info.add_instruction(make_op("identity"), input_arg)};

        auto lens     = input_arg->get_shape().lens();
        auto num_dims = lens.size();
        int axis      = axis_arg->eval().at<int32_t>();
        // TF allows a negative split axis; normalize before indexing lens
        if(axis < 0)
            axis += static_cast<int>(num_dims);

        // ensure split is made evenly if "num_split" is used
        assert(vector_as_input or lens[axis] % num_outputs == 0);
        auto split_size = lens[axis] / num_outputs;

        // boundary positions along the split axis:
        // slice i covers [slice_pos[i], slice_pos[i + 1])
        std::vector<int> slice_pos{0};
        if(vector_as_input)
        {
            int running = 0;
            for(auto s : splits)
            {
                running += s;
                slice_pos.push_back(running);
            }
        }
        else
        {
            for(int i = 1; i <= num_outputs; i++)
                slice_pos.push_back(i * split_size);
        }

        std::vector<instruction_ref> result;
        for(int i = 0; i < num_outputs; i++)
        {
            // full-range slice on every axis except the split axis
            std::vector<int64_t> axes(num_dims);
            std::iota(axes.begin(), axes.end(), 0);
            std::vector<int64_t> starts(num_dims, 0);
            std::vector<int64_t> ends(lens.begin(), lens.end());
            starts[axis] = slice_pos[i];
            ends[axis]   = slice_pos[i + 1];
            result.push_back(info.add_instruction(
                make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}),
                input_arg));
        }
        return result;
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_squeeze : op_parser<parse_squeeze>
{
    std::vector<op_desc> operators() const { return {{"Squeeze"}}; }

    /// Lower TF Squeeze. The "squeeze_dims" attribute lists the axes to
    /// remove; when it is empty or absent, every dimension of size 1 is
    /// squeezed (TF's default behavior).
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          tf_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        auto input_dims = args[0]->get_shape().lens();
        std::vector<int64_t> op_axes;
        // guard the lookup: .at() would throw std::out_of_range when the
        // attribute is missing, even though the empty case is handled below
        if(contains(info.attributes, "squeeze_dims"))
        {
            auto axes = info.attributes.at("squeeze_dims").list().i();
            op_axes.assign(axes.begin(), axes.end());
        }
        if(op_axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
        {
            for(size_t i = 0; i < input_dims.size(); i++)
            {
                if(input_dims.at(i) == 1)
                {
                    op_axes.push_back(i);
                }
            }
        }
        return info.add_instruction(make_op("squeeze", {{"axes", op_axes}}),
                                    info.make_contiguous(args[0]));
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_strideslice : op_parser<parse_strideslice>
{
    std::vector<op_desc> operators() const { return {{"StridedSlice"}}; }

    // Lower TF StridedSlice to a migraphx "slice", followed by a "squeeze"
    // when shrink_axis_mask requests dropping dimensions.
    // args[0]: input tensor; args[1]: begin indices; args[2]: end indices
    // (both must be constant-foldable).
    // NOTE(review): the strides input and the ellipsis_mask/new_axis_mask
    // attributes are not handled here -- presumably the importer only
    // produces unit-stride slices; confirm against callers.
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          tf_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        auto starts = args[1]->eval().get<int32_t>().to_vector();
        auto ends = args[2]->eval().get<int32_t>().to_vector();
        auto l0 = args[0];
        size_t num_axes = l0->get_shape().lens().size();
        // input dimensions, used to expand masked-out end indices to "full axis"
        std::vector<size_t> axes = l0->get_shape().lens();
        std::vector<int64_t> op_starts(starts.begin(), starts.end());
        std::vector<int64_t> op_ends(ends.begin(), ends.end());
        // slice along every axis of the input
        std::vector<int64_t> op_axes(num_axes);
        std::iota(op_axes.begin(), op_axes.end(), 0);
        // bit i of each mask refers to axis i (LSB = axis 0)
        uint32_t begin_mask = 0;
        uint32_t end_mask = 0;
        uint32_t shrink_axis_mask = 0;
        uint32_t bitwise_compare = 1;
        std::vector<int64_t> squeeze_axes;
        if(contains(info.attributes, "begin_mask"))
            begin_mask = static_cast<uint32_t>(info.attributes.at("begin_mask").i());
        if(contains(info.attributes, "end_mask"))
            end_mask = static_cast<uint32_t>(info.attributes.at("end_mask").i());
        if(contains(info.attributes, "shrink_axis_mask"))
            shrink_axis_mask = static_cast<uint32_t>(info.attributes.at("shrink_axis_mask").i());
        // expand each bit mask into a per-axis 0/1 vector
        std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
        std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
        for(size_t i = 0; i < num_axes; i++)
        {
            // a set begin bit means "ignore the given begin; start at 0"
            if(begin_axes.at(i) == 1)
            {
                op_starts.at(i) = 0;
            }
            // a set end bit means "ignore the given end; run to the axis length"
            if(end_axes.at(i) == 1)
            {
                op_ends.at(i) = axes.at(i);
            }
        }
        auto op = make_op("slice", {{"starts", op_starts}, {"ends", op_ends}, {"axes", op_axes}});
        auto l1 = info.add_instruction(op, l0);
        // no axes to drop: the slice alone is the result
        if(shrink_axis_mask == 0)
            return l1;
        for(size_t i = 0; i < num_axes; i++)
        {
            // the LSB corresponds to axis 0 when determining which axes to squeeze
            if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                squeeze_axes.push_back(i);
        }
        return info.add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1);
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_transpose : op_parser<parse_transpose>
{
    std::vector<op_desc> operators() const { return {{"Transpose"}}; }

    /// args[1] holds the permutation as a constant int32 tensor; widen it
    /// to int64 and emit a transpose instruction on args[0].
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          const tf_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto perm32 = args[1]->eval().get<int32_t>().to_vector();
        std::vector<int64_t> permutation(perm32.begin(), perm32.end());
        return info.add_instruction(make_op("transpose", {{"dims", permutation}}),
                                    args.front());
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
This diff is collapsed.
This diff is collapsed.
...@@ -338,6 +338,15 @@ def pack_test_nhwc(g1): ...@@ -338,6 +338,15 @@ def pack_test_nhwc(g1):
tf.stack([g1_input, g2_input, g3_input], axis=3, name='pack1') tf.stack([g1_input, g2_input, g3_input], axis=3, name='pack1')
@tf_test
def pad_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float32, shape=(2, 4), name='0')
paddings = tf.constant([[1, 1], [2, 2]])
tf.pad(g1_input, paddings, name='pad1')
@tf_test @tf_test
def pooling_test(g1): def pooling_test(g1):
with g1.as_default(): with g1.as_default():
...@@ -579,10 +588,11 @@ if __name__ == '__main__': ...@@ -579,10 +588,11 @@ if __name__ == '__main__':
mean_test() mean_test()
mean_test_nhwc() mean_test_nhwc()
mul_test() mul_test()
onehot_test()
noop_test() noop_test()
onehot_test()
pack_test() pack_test()
pack_test_nhwc() pack_test_nhwc()
pad_test()
pooling_test() pooling_test()
pow_test() pow_test()
relu_test() relu_test()
......
...@@ -71,10 +71,8 @@ TEST_CASE(add_bcast_test) ...@@ -71,10 +71,8 @@ TEST_CASE(add_bcast_test)
auto l0 = mm->add_parameter("0", s0); auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 = auto l2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l0);
auto l3 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l2, l3); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("add_bcast_test.pb", false); auto prog = optimize_tf("add_bcast_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -546,6 +544,23 @@ TEST_CASE(pack_test_nhwc) ...@@ -546,6 +544,23 @@ TEST_CASE(pack_test_nhwc)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(pad_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
std::vector<int> pad_literals{1, 1, 2, 2};
std::vector<int> pads{1, 2, 1, 2};
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, pad_literals);
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), l0);
auto prog = optimize_tf("pad_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pooling_test) TEST_CASE(pooling_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment