Unverified commit d237415f, authored by kahmed10 and committed by GitHub

Refactor tf parser (#726)



* refactor files

* formatting

* fix add_bcast test

* fix some tidy errors

* add transpose field and fix more tidy

* formatting

* add pad test and fix more tidy

* formatting

* fix conv parser

* fix depthwiseconv

* remove unused functions

* remove includes and functions
Co-authored-by: Paul Fultz II <pfultz2@yahoo.com>
parent 3d24a21c
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_pad : op_parser<parse_pad>
{
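// Returning true opts this parser into the tf_parser's NHWC<->NCHW layout
// handling; this appears to mirror the `transpose` flag of the old
// add_op(..., transpose = true) registration removed later in this diff.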
bool transpose() const { return true; }
std::vector<op_desc> operators() const { return {{"Pad"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& parser,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
size_t ndims = args.front()->get_shape().lens().size();
// in tf, the paddings are arranged as a 2d shape (ndims, 2),
// the last dim contains the left padding and right padding respectively
std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
for(size_t i = 0; i < 2 * ndims; i += 2)
{
pad_per_dim[i / 2].first = tf_padding[i];
pad_per_dim[i / 2].second = tf_padding[i + 1];
}
parser.reorder_data(pad_per_dim);
std::vector<int64_t> pads(ndims * 2);
for(size_t i = 0; i < ndims; i++)
{
pads[i] = pad_per_dim[i].first;
pads[i + ndims] = pad_per_dim[i].second;
}
return info.add_instruction(make_op("pad", {{"pads", pads}}), args.front());
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
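// Worked example (illustrative, not part of the commit): for a 4-D NHWC input
// with TF paddings [[0,0],[1,2],[3,4],[0,0]], the flattened attribute is
// {0,0, 1,2, 3,4, 0,0}, so pad_per_dim = {(0,0),(1,2),(3,4),(0,0)}.
// reorder_data() permutes this into NCHW order, {(0,0),(0,0),(1,2),(3,4)},
// and the resulting migraphx "pads" vector is {0,0,1,3, 0,0,2,4}
// (all leading pads first, then all trailing pads).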
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_pooling : op_parser<parse_pooling>
{
bool transpose() const { return true; }
std::vector<op_desc> operators() const { return {{"AvgPool"}, {"MaxPool"}}; }
instruction_ref parse(const op_desc& opd,
const tf_parser& parser,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
op::pooling op{starts_with(opd.tf_name, "Max") ? "max" : "average"};
if(contains(info.attributes, "strides"))
{
std::vector<size_t> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
if(contains(info.attributes, "ksize"))
{
std::vector<size_t> ksize;
copy(info.attributes.at("ksize").list().i(), std::back_inserter(ksize));
parser.reorder_data(ksize);
if(ksize.size() != 4)
{
MIGRAPHX_THROW("ksize should have 4 values");
}
op.lengths[0] = ksize[2];
op.lengths[1] = ksize[3];
}
auto l0 = args[0];
if(contains(info.attributes, "padding"))
{
const std::string& pad_mode = info.attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
auto input_dims = l0->get_shape().lens();
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(
migraphx::make_op(
"pad",
{{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
}
}
return info.add_instruction(op, l0);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
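// Worked example (illustrative, assuming calculate_padding follows TF SAME
// semantics with pad_before = total_padding / 2): input H = 7, stride 2,
// window 2 gives output ceil(7/2) = 4 and total padding (4-1)*2 + 2 - 7 = 1,
// i.e. 0 before and 1 after. The split is asymmetric, so an explicit pad
// instruction is inserted; filling with std::numeric_limits<float>::lowest()
// makes the padded cells act as the identity for max pooling.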
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_relu6 : op_parser<parse_relu6>
{
std::vector<op_desc> operators() const { return {{"Relu6"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_lens = args[0]->get_shape().lens();
auto min_val = info.add_literal(0.0f);
auto max_val = info.add_literal(6.0f);
min_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
max_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
return info.add_instruction(make_op("clip"), args.front(), min_val, max_val);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_reshape : op_parser<parse_reshape>
{
std::vector<op_desc> operators() const { return {{"Reshape"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
if(args.size() != 2)
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
auto s = args[1]->eval();
std::vector<int64_t> dims;
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
return info.add_instruction(make_op("reshape", {{"dims", dims}}),
info.make_contiguous(args[0]));
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_shape : op_parser<parse_shape>
{
std::vector<op_desc> operators() const { return {{"Shape"}}; }
// Use a literal instruction to replace Shape, since the output of the
// shape operator is a literal in migraphx
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int32_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()});
std::transform(
arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
return info.add_literal(migraphx::literal{s, vec_shape});
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
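// Example (illustrative): for an input of shape (2, 3, 4), this parser emits an
// int32 literal of shape {3} holding {2, 3, 4}, so shape-driven computations
// can be folded at parse time.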
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_slice : op_parser<parse_slice>
{
std::vector<op_desc> operators() const { return {{"Slice"}}; }
// TF's Slice takes (input, begin, size); convert each size into an end
// index for the migraphx slice operator
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto size = args[2]->eval().get<int32_t>().to_vector();
auto axes = args[0]->get_shape().lens();
size_t num_axes = axes.size();
std::vector<int64_t> axes_int64(axes.begin(), axes.end());
std::vector<int64_t> starts_int64(starts.begin(), starts.end());
std::vector<int64_t> ends(num_axes);
std::vector<int64_t> op_axes(num_axes);
std::iota(op_axes.begin(), op_axes.end(), 0);
for(size_t i = 0; i < num_axes; i++)
{
if(size[i] == -1)
ends[i] = axes_int64[i];
else
ends[i] = starts_int64[i] + size[i];
}
auto op = make_op("slice", {{"starts", starts_int64}, {"ends", ends}, {"axes", op_axes}});
return info.add_instruction(op, info.make_contiguous(args[0]));
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
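// Worked example (illustrative): for input lens (5, 6) with begin = {1, 2} and
// size = {3, -1}, the op gets starts = {1, 2}, ends = {1+3, 6} = {4, 6}
// (a size of -1 extends the slice to the end of the axis), and axes = {0, 1}.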
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_softmax : op_parser<parse_softmax>
{
std::vector<op_desc> operators() const { return {{"Softmax"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
int axis = -1;
auto num_dims = args[0]->get_shape().lens().size();
if(contains(info.attributes, "axis"))
{
axis = static_cast<int>(info.attributes.at("axis").i());
}
axis = tune_axis(num_dims, axis, "tf_parse_softmax");
return info.add_instruction(make_op("softmax", {{"axis", axis}}),
info.make_contiguous(args[0]));
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_split : op_parser<parse_split>
{
std::vector<op_desc> operators() const { return {{"Split"}, {"SplitV"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
bool vector_as_input = args.size() == 3;
int num_outputs = 1;
auto axis_arg = args[0];
auto input_arg = args[1];
if(vector_as_input)
{
input_arg = args[0];
axis_arg = args[2];
}
if(contains(info.attributes, "num_split"))
num_outputs = info.attributes.at("num_split").i();
std::vector<int> splits(num_outputs);
std::vector<int> slice_pos{0};
if(vector_as_input)
{
splits = args[1]->eval().get<int32_t>().to_vector();
num_outputs = splits.size();
}
assert(num_outputs > 0);
if(num_outputs == 1)
return std::vector<instruction_ref>{
info.add_instruction(make_op("identity"), input_arg)};
auto lens = input_arg->get_shape().lens();
auto num_dims = lens.size();
int axis = axis_arg->eval().at<int32_t>();
// ensure split is made evenly if "num_split" is used
assert(vector_as_input or lens[axis] % num_outputs == 0);
auto split_size = lens[axis] / num_outputs;
// push back first end point of slice
if(vector_as_input)
{
slice_pos.push_back(splits[0]);
}
else
{
slice_pos.push_back(split_size);
}
// calculate remaining end points for each slice
for(auto i = 1; i < num_outputs; i++)
{
if(vector_as_input)
{
splits[i] += splits[i - 1];
slice_pos.push_back(splits[i]);
}
else
{
slice_pos.push_back((i + 1) * split_size);
}
}
std::vector<instruction_ref> result;
for(auto i = 0; i < num_outputs; i++)
{
std::vector<int64_t> axes(num_dims);
std::iota(axes.begin(), axes.end(), 0);
std::vector<int64_t> starts(num_dims, 0);
std::vector<int64_t> ends(lens.begin(), lens.end());
starts[axis] = slice_pos[i];
ends[axis] = slice_pos[i + 1];
auto op = make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}});
result.push_back(info.add_instruction(op, input_arg));
}
return result;
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
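// Worked example (illustrative): SplitV with splits = {2, 3, 5} on an axis of
// length 10 accumulates slice_pos = {0, 2, 5, 10} (a running prefix sum) and
// emits slices [0,2), [2,5), [5,10). Plain Split with num_split = 2 on the
// same axis gives split_size = 5 and slice_pos = {0, 5, 10}.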
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_squeeze : op_parser<parse_squeeze>
{
std::vector<op_desc> operators() const { return {{"Squeeze"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto input_dims = args[0]->get_shape().lens();
auto axes = info.attributes.at("squeeze_dims").list().i();
std::vector<int64_t> op_axes(axes.begin(), axes.end());
if(op_axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{
for(size_t i = 0; i < input_dims.size(); i++)
{
if(input_dims.at(i) == 1)
{
op_axes.push_back(i);
}
}
}
return info.add_instruction(make_op("squeeze", {{"axes", op_axes}}),
info.make_contiguous(args[0]));
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_strideslice : op_parser<parse_strideslice>
{
std::vector<op_desc> operators() const { return {{"StridedSlice"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
auto l0 = args[0];
size_t num_axes = l0->get_shape().lens().size();
std::vector<size_t> axes = l0->get_shape().lens();
std::vector<int64_t> op_starts(starts.begin(), starts.end());
std::vector<int64_t> op_ends(ends.begin(), ends.end());
std::vector<int64_t> op_axes(num_axes);
std::iota(op_axes.begin(), op_axes.end(), 0);
uint32_t begin_mask = 0;
uint32_t end_mask = 0;
uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1;
std::vector<int64_t> squeeze_axes;
if(contains(info.attributes, "begin_mask"))
begin_mask = static_cast<uint32_t>(info.attributes.at("begin_mask").i());
if(contains(info.attributes, "end_mask"))
end_mask = static_cast<uint32_t>(info.attributes.at("end_mask").i());
if(contains(info.attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(info.attributes.at("shrink_axis_mask").i());
std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
for(size_t i = 0; i < num_axes; i++)
{
if(begin_axes.at(i) == 1)
{
op_starts.at(i) = 0;
}
if(end_axes.at(i) == 1)
{
op_ends.at(i) = axes.at(i);
}
}
auto op = make_op("slice", {{"starts", op_starts}, {"ends", op_ends}, {"axes", op_axes}});
auto l1 = info.add_instruction(op, l0);
if(shrink_axis_mask == 0)
return l1;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
squeeze_axes.push_back(i);
}
return info.add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
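// Worked example (illustrative): with num_axes = 3 and input lens {4, 5, 6},
// begin_mask = 5 (0b101) selects axes 0 and 2, so op_starts at those axes are
// reset to 0; end_mask = 0 leaves op_ends untouched; shrink_axis_mask = 2
// (0b010) marks axis 1, so the sliced result is squeezed along axis 1.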
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_transpose : op_parser<parse_transpose>
{
std::vector<op_desc> operators() const { return {{"Transpose"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto perm = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> dims(perm.begin(), perm.end());
return info.add_instruction(make_op("transpose", {{"dims", dims}}), args.front());
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <migraphx/tf/tf_parser.hpp>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <utility>
#include <vector>
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct tf_parser
{
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::map<std::string, tensorflow::NodeDef>;
using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
module* mm = prog.get_main_module();
bool is_nhwc = true;
unsigned int batch_size = 1;
// Specified dims of inputs
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, op_func> ops;
bool should_transpose(instruction_ref ins) const
{
return is_nhwc and ins->get_shape().lens().size() == 4;
}
instruction_ref to_nhwc(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
return ins;
}
instruction_ref to_nchw(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
return ins;
}
instruction_ref to_kcxy(instruction_ref ins) const
{
return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
}
instruction_ref make_contiguous(instruction_ref ins) const
{
if(ins->get_shape().standard())
return ins;
else
return mm->add_instruction(make_op("contiguous"), ins);
}
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
{
std::vector<instruction_ref> result(args.size());
std::transform(
args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
return result;
}
std::vector<instruction_ref> to_nhwc(const std::vector<instruction_ref>& args)
{
std::vector<instruction_ref> result(args.size());
std::transform(
args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); });
return result;
}
std::vector<size_t>
parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
{
auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes;
copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
if(is_nhwc)
{
std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
return parse_axis(axis, num_dims);
});
}
return axes;
}
template <class T>
std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
{
if(is_nhwc)
{
std::vector<T> new_axes;
std::transform(axes.begin(),
axes.end(),
std::back_inserter(new_axes),
[&](size_t axis) { return parse_axis(axis, num_dims); });
return new_axes;
}
return axes;
}
// tf stores certain attributes, such as strides and dilations, as 4-D values in
// the layout of the input data (e.g. {1, h, w, 1} for NHWC). This helper reorders
// the values into NCHW order, so the relevant data lands in dims 2 and 3 before
// it is copied into the corresponding operator fields.
template <class T>
void reorder_data(std::vector<T>& prev_data) const
{
std::vector<T> new_data(prev_data.size());
for(size_t i = 0; i < new_data.size(); i++)
{
auto new_idx = parse_axis(i, new_data.size());
new_data.at(new_idx) = prev_data.at(i);
}
prev_data = new_data;
}
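// Worked example (illustrative): an NHWC strides attribute of {1, 2, 3, 1}
// becomes {1, 1, 2, 3} in NCHW order, after which op.stride[0] = 2 (H) and
// op.stride[1] = 3 (W) are read from dims 2 and 3.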
template <class T>
T parse_axis(const T& dim, const size_t num_dims) const
{
T new_dim = dim;
if(is_nhwc and num_dims >= 4)
{
switch(dim)
{
case 0: new_dim = 0; break;
case 1: new_dim = 2; break;
case 2: new_dim = 3; break;
case 3: new_dim = 1; break;
default: break;
}
}
return new_dim;
}
std::vector<int64_t> get_axes(size_t num_axes) const
{
std::vector<int64_t> axes(num_axes);
std::iota(axes.begin(), axes.end(), 0);
return axes;
}
std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
{
uint32_t bitwise_compare = 1;
std::vector<int64_t> axes;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes the mask selects
if(((mask >> i) & bitwise_compare) == 1)
axes.push_back(1);
else
axes.push_back(0);
}
return axes;
}
tf_parser()
{
add_generic_op("All", make_op("identity"));
add_generic_op("Identity", make_op("identity"));
add_generic_op("LessEqual", make_op("identity"));
add_generic_op("Relu", make_op("relu"));
add_generic_op("Rsqrt", make_op("rsqrt"));
add_generic_op("Tanh", make_op("tanh"));
add_generic_op("StopGradient", make_op("identity"));
add_binary_op("Add", make_op("add"));
add_binary_op("AddV2", make_op("add"));
add_binary_op("Mul", make_op("mul"));
add_binary_op("Pow", make_op("pow"));
add_binary_op("SquaredDifference", make_op("sqdiff"));
add_binary_op("Sub", make_op("sub"));
add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("Cast", &tf_parser::parse_cast, false);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("FusedBatchNormV3", &tf_parser::parse_batchnorm);
add_mem_op("GatherV2", &tf_parser::parse_gather, false);
add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean, false);
add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Relu6", &tf_parser::parse_relu6);
add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Shape", &tf_parser::parse_shape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Split", &tf_parser::parse_split, false);
add_mem_op("SplitV", &tf_parser::parse_split, false);
add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
add_mem_op("Transpose", &tf_parser::parse_transpose, false);
}
template <class F>
void add_op(const std::string& name, F f, bool transpose = true)
{
if(transpose)
{
ops.emplace(
name,
op_func{
[=](const attribute_map& attributes, const std::vector<instruction_ref>& args) {
return std::vector<instruction_ref>{to_nhwc(f(attributes, to_nchw(args)))};
}});
}
else
{
ops.emplace(name,
op_func{[=](const attribute_map& attributes,
const std::vector<instruction_ref>& args) {
return std::vector<instruction_ref>{f(attributes, args)};
}});
}
}
template <class F>
void add_mem_op(std::string name, F f, bool transpose = true)
{
add_op(name,
[=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
},
transpose);
}
template <class T>
void add_binary_op(std::string name, T x)
{
add_op(name,
[this, x](const attribute_map&, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
// TODO
// if(contains(attributes, "data_format"))
// {
// if(is_nhwc)
// {
// l0 = mm->add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// }
// }
return add_broadcastable_binary_op(args[0], args[1], x);
},
false);
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<size_t>* s0 = &arg0->get_shape().lens();
const std::vector<size_t>* s1 = &arg1->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
std::vector<size_t> output_lens(*s1);
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
arg0);
auto l1 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
arg1);
return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
}
else
{
return to_nhwc(mm->add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
}
}
template <class T>
void add_generic_op(std::string name, T x, bool transpose = true)
{
add_op(name,
[this, x](const attribute_map&, std::vector<instruction_ref> args) {
return mm->add_instruction(x, args);
},
transpose);
}
template <class Op>
instruction_ref
parse_arg_op(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
int64_t axis = 0;
axis = args[1]->eval().at<int64_t>();
auto ins = mm->add_instruction(Op{axis}, args.front());
return mm->add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
}
instruction_ref parse_batchnorm(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(attributes, "epsilon"))
{
epsilon = attributes.at("epsilon").f();
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return mm->add_instruction(op, std::move(args));
}
instruction_ref
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = mm->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
return mm->add_instruction(make_op("add"), args[0], l0);
}
instruction_ref parse_cast(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
shape::type_t type = parse_type(attributes.at("DstT").type());
return mm->add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
}
instruction_ref parse_concat(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
// get index for axis within args
size_t axis_idx = attributes.at("N").i();
int64_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis};
// return only first N arguments (assuming last index is the axis value)
return mm->add_instruction(
op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1));
}
instruction_ref parse_constant(const std::string&,
attribute_map attributes,
const std::vector<instruction_ref>&) const
{
literal v = parse_tensor(attributes.at("value").tensor());
return mm->add_literal(v);
}
instruction_ref parse_conv(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
op::convolution op;
if(contains(attributes, "strides"))
{
std::vector<size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
if(contains(attributes, "dilations"))
{
std::vector<size_t> dilation;
copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
reorder_data(dilation);
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
auto weights = to_kcxy(args[1]);
auto l0 = args[0];
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens();
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
return mm->add_instruction(op, {l0, weights});
}
instruction_ref parse_depthwiseconv(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1];
op.group = num_channels;
if(contains(attributes, "strides"))
{
std::vector<size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
auto weights = to_kcxy(args[1]);
if(contains(attributes, "dilations"))
{
std::vector<size_t> dilation;
copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
reorder_data(dilation);
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
auto l0 = args[0];
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens();
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
std::vector<int64_t> new_weights_shape;
copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
// weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
// out_channels is equal to the multiplier. Adjust by inserting a reshape and
// setting in_channels to 1
int64_t multiplier = new_weights_shape[0];
int64_t out_channels = num_channels * multiplier;
new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape
auto new_weights = mm->add_instruction(make_op("reshape", {{"dims", new_weights_shape}}),
make_contiguous(weights));
return mm->add_instruction(op, {l0, new_weights});
}
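// Worked example (illustrative, assuming TF depthwise weights are laid out as
// (h, w, in_channels, multiplier)): with 8 input channels, multiplier 2, and a
// 3x3 filter, to_kcxy() gives weights of shape (2, 8, 3, 3); the reshape then
// produces (16, 1, 3, 3) and the convolution runs with group = 8, so each
// input channel is convolved with its own pair of filters.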
instruction_ref parse_expanddims(const std::string&,
const attribute_map&,
std::vector<instruction_ref> args) const
{
std::vector<size_t> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
size_t num_dims = input_dims.size();
int32_t dim = args[1]->eval().at<int32_t>();
if(dim < 0)
{
new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
}
else
{
new_dims.insert(new_dims.begin() + dim, 1);
}
return mm->add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]);
}
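// Example (illustrative): for an input of shape (2, 3), dim = 0 yields
// dims {1, 2, 3}, while dim = -1 inserts at position num_dims + dim + 1 = 2,
// yielding {2, 3, 1}.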
instruction_ref
parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
int axis = args[2]->eval().at<int32_t>();
op::gather op{axis};
return mm->add_instruction(op, {args[0], args[1]});
}
instruction_ref parse_matmul(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
bool transa = false;
bool transb = false;
if(contains(attributes, "transpose_a"))
{
transa = attributes.at("transpose_a").b();
}
if(contains(attributes, "transpose_b"))
{
transb = attributes.at("transpose_b").b();
}
if(contains(attributes, "adj_x"))
{
transa = attributes.at("adj_x").b();
}
if(contains(attributes, "adj_y"))
{
transb = attributes.at("adj_y").b();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
return mm->add_instruction(make_op("dot"), l1, l2);
}
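// Example (illustrative): for rank-2 inputs perm becomes {1, 0}; for a rank-3
// BatchMatMul it becomes {0, 2, 1}, so only the last two dims of an operand
// with transpose_a/transpose_b (or adj_x/adj_y) set are swapped.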
instruction_ref parse_mean(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
bool keep_dims = attributes.at("keep_dims").b();
auto axes = args[1]->eval().get<int32_t>().to_vector<int64_t>();
if(keep_dims)
{
return mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
}
else
{
auto ins = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
return mm->add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
}
}
instruction_ref parse_onehot(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
int64_t axis = -1;
float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>();
std::vector<float> depth_input(depth * depth, off_value);
for(int i = 0; i < depth; i++)
{
depth_input[depth * i + i] = on_value;
}
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
if(axis == -1)
{
shape s{shape::float_type, {depth, depth}};
auto l0 = mm->add_literal({s, depth_input});
return mm->add_instruction(make_op("gather", {{"axis", 0}}), {l0, args[0]});
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
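// Worked example (illustrative): depth = 3, on_value = 1, off_value = 0 builds
// the 3x3 identity matrix as a literal; gathering along axis 0 with indices
// {2, 0} then returns the one-hot rows {0, 0, 1} and {1, 0, 0}.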
instruction_ref parse_pack(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args) const
{
// reinterpret as unsqueeze with concat
std::vector<instruction_ref> unsqueezed_args;
int64_t axis = 0;
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
size_t input_size = args.front()->get_shape().lens().size();
if(axis > input_size)
{
MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
" must be smaller than input size " + to_string(input_size));
}
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) {
return mm->add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
return to_nhwc(mm->add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
}
instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
size_t ndims = args.front()->get_shape().lens().size();
// in tf, the paddings are arranged as a 2d shape (ndims, 2),
// the last dim contains the left padding and right padding respectively
std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
for(size_t i = 0; i < 2 * ndims; i += 2)
{
pad_per_dim[i / 2].first = tf_padding[i];
pad_per_dim[i / 2].second = tf_padding[i + 1];
}
reorder_data(pad_per_dim);
op::pad op;
std::vector<int64_t> pads(ndims * 2);
for(size_t i = 0; i < ndims; i++)
{
pads[i] = pad_per_dim[i].first;
pads[i + ndims] = pad_per_dim[i].second;
}
op.pads = pads;
return mm->add_instruction(op, args.front());
}
instruction_ref parse_pooling(const std::string& name,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
op::pooling op{starts_with(name, "Max") ? "max" : "average"};
if(contains(attributes, "strides"))
{
std::vector<size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
if(contains(attributes, "ksize"))
{
std::vector<size_t> ksize;
copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
reorder_data(ksize);
if(ksize.size() != 4)
{
MIGRAPHX_THROW("ksize should have 4 values");
}
op.lengths[0] = ksize[2];
op.lengths[1] = ksize[3];
}
auto l0 = args[0];
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
auto input_dims = l0->get_shape().lens();
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(
migraphx::make_op(
"pad",
{{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
}
}
return mm->add_instruction(op, l0);
}
instruction_ref
parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
auto input_lens = args[0]->get_shape().lens();
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
max_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
return mm->add_instruction(make_op("clip"), args.front(), min_val, max_val);
}
instruction_ref
parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
op::reshape op;
if(args.size() != 2)
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
auto s = args[1]->eval();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return mm->add_instruction(op, make_contiguous(args[0]));
}
// Use a literal instruction to replace Shape, since the output of the
// shape operator is a literal in migraphx
instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int32_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()});
std::transform(
arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
return mm->add_literal(migraphx::literal{s, vec_shape});
}
instruction_ref
parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto size = args[2]->eval().get<int32_t>().to_vector();
auto axes = args[0]->get_shape().lens();
size_t num_axes = axes.size();
op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(num_axes);
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
for(size_t i = 0; i < num_axes; i++)
{
if(size[i] == -1)
op.ends[i] = axes[i];
else
op.ends[i] = starts[i] + size[i];
}
return mm->add_instruction(op, make_contiguous(args[0]));
}
// templated so that a logsoftmax parser can reuse this later
template <class Op>
instruction_ref parse_softmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = -1;
auto num_dims = args[0]->get_shape().lens().size();
if(contains(attributes, "axis"))
{
axis = static_cast<int>(attributes.at("axis").i());
}
axis = tune_axis(num_dims, axis, "tf_parse_softmax");
return mm->add_instruction(Op{axis}, make_contiguous(args[0]));
}
std::vector<instruction_ref> parse_split(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args) const
{
bool vector_as_input = args.size() == 3;
int num_outputs = 1;
auto axis_arg = args[0];
auto input_arg = args[1];
if(vector_as_input)
{
input_arg = args[0];
axis_arg = args[2];
}
if(contains(attributes, "num_split"))
num_outputs = attributes.at("num_split").i();
std::vector<int> splits(num_outputs);
std::vector<int> slice_pos{0};
if(vector_as_input)
{
splits = args[1]->eval().get<int32_t>().to_vector();
num_outputs = splits.size();
}
assert(num_outputs > 0);
if(num_outputs == 1)
return std::vector<instruction_ref>{
mm->add_instruction(make_op("identity"), input_arg)};
auto lens = input_arg->get_shape().lens();
auto num_dims = lens.size();
int axis = axis_arg->eval().at<int32_t>();
// ensure split is made evenly if "num_split" is used
assert(vector_as_input or lens[axis] % num_outputs == 0);
auto split_size = lens[axis] / num_outputs;
// push back first end point of slice
if(vector_as_input)
{
slice_pos.push_back(splits[0]);
}
else
{
slice_pos.push_back(split_size);
}
// calculate remaining end points for each slice
for(auto i = 1; i < num_outputs; i++)
{
if(vector_as_input)
{
splits[i] += splits[i - 1];
slice_pos.push_back(splits[i]);
}
else
{
slice_pos.push_back((i + 1) * split_size);
}
}
std::vector<instruction_ref> result;
for(auto i = 0; i < num_outputs; i++)
{
op::slice op;
op.axes = std::vector<int64_t>(num_dims);
std::iota(op.axes.begin(), op.axes.end(), 0);
op.starts = std::vector<int64_t>(num_dims, 0);
op.ends = std::vector<int64_t>(lens.begin(), lens.end());
op.starts[axis] = slice_pos[i];
op.ends[axis] = slice_pos[i + 1];
result.push_back(mm->add_instruction(op, input_arg));
}
return result;
}
instruction_ref parse_squeeze(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args) const
{
op::squeeze op;
auto input_dims = args[0]->get_shape().lens();
auto axes = attributes.at("squeeze_dims").list().i();
copy(axes, std::back_inserter(op.axes));
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{
for(size_t i = 0; i < input_dims.size(); i++)
{
if(input_dims.at(i) == 1)
{
op.axes.push_back(i);
}
}
}
return mm->add_instruction(op, make_contiguous(args[0]));
}
instruction_ref parse_stridedslice(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
auto l0 = args[0];
size_t num_axes = l0->get_shape().lens().size();
std::vector<size_t> axes = l0->get_shape().lens();
op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end());
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
uint32_t begin_mask = 0;
uint32_t end_mask = 0;
uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1;
std::vector<int64_t> squeeze_axes;
if(contains(attributes, "begin_mask"))
begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());
if(contains(attributes, "end_mask"))
end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());
if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
for(size_t i = 0; i < num_axes; i++)
{
if(begin_axes.at(i) == 1)
{
op.starts.at(i) = 0;
}
if(end_axes.at(i) == 1)
{
op.ends.at(i) = axes.at(i);
}
}
auto l1 = mm->add_instruction(op, l0);
if(shrink_axis_mask == 0)
return l1;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
squeeze_axes.push_back(i);
}
return mm->add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1);
}
instruction_ref parse_transpose(const std::string&,
const attribute_map&,
std::vector<instruction_ref> args) const
{
auto perm = args[1]->eval().get<int32_t>().to_vector();
op::transpose op;
op.dims = std::vector<int64_t>(perm.begin(), perm.end());
return mm->add_instruction(op, args.front());
}
void parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
for(auto&& input : input_nodes)
{
const std::string& name = input.name();
attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
if(contains(map_input_dims, name))
{
dims = map_input_dims.at(name);
}
else
{
if(is_nhwc and dims.size() >= 4)
{
reorder_data(dims);
}
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
return static_cast<int>(dim) <= 0 ? batch_size : dim;
});
}
shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(mm->add_parameter(name, s));
}
for(auto&& p : nodes)
{
this->parse_node(p.first);
}
// A return instruction still needs to be added at the end of
// the program
}
void parse_node(const std::string& name)
{
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
// assert ops ignored
if(node.op() == "Assert" or contains(name, "Assert"))
return;
// noOps ignored
if(node.op() == "NoOp" or contains(name, "NoOp"))
return;
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
// control dependencies (signified by ^ before the name) are ignored
if(contains(input, "^"))
continue;
if(nodes.count(input) > 0)
{
std::string iname;
// input was from a node with multiple outputs
if(contains(input, ':'))
{
iname = input.substr(0, input.find(':'));
}
else
{
iname = get_name(nodes.at(input));
}
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(input));
}
else
{
args.push_back(instructions.at(input));
}
}
std::vector<instruction_ref> result;
if(ops.count(node.op()) == 0)
{
result.push_back(mm->add_instruction(op::unknown{node.op()}, args));
}
else
{
result = ops[node.op()](get_attributes(node), args);
}
assert(!result.empty());
// First output has no ":" delimiter
instructions[name] = result.front();
for(size_t i = 1; i < result.size(); i++)
{
instructions[name + ":" + std::to_string(i)] = result.at(i);
}
}
}
void parse_from(std::istream& is)
{
tensorflow::GraphDef graph;
if(graph.ParseFromIstream(&is))
{
this->parse_graph(graph);
}
else
{
throw std::runtime_error("Failed reading tf file");
}
}
static attribute_map get_attributes(const tensorflow::NodeDef& node)
{
attribute_map result;
for(auto&& attr : node.attr())
{
result[attr.first] = attr.second;
}
return result;
}
static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); }
static node_map get_nodes(const tensorflow::GraphDef& graph,
std::vector<tensorflow::NodeDef>& input_nodes)
{
node_map result;
for(auto&& node : graph.node())
{
auto node_name = get_name(node);
// assume each node in graph has an associated name
if(node_name.empty())
MIGRAPHX_THROW("tf node with no name found");
result[node_name] = node;
if(node.op() == "Placeholder")
{
input_nodes.push_back(node);
}
}
return result;
}
static shape::type_t parse_type(const tensorflow::DataType t)
{
shape::type_t shape_type{};
switch(t)
{
case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_BOOL:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
// tf pb should not use these types
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
}
return shape_type;
}
static literal parse_tensor(const tensorflow::TensorProto& t)
{
std::vector<size_t> dims = parse_dims(t.tensor_shape());
size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data
{
const std::string& s = t.tensor_content();
switch(t.dtype())
{
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_BOOL:
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
case tensorflow::DataType::DT_INT16:
return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.dtype())
{
case tensorflow::DataType::DT_FLOAT:
return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_INT8:
return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16:
return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT16:
return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT32:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT64:
return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF:
{
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half);
}
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
}
template <class T>
static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
const size_t& shape_size)
{
std::vector<T> data_vals(shape_size);
// a single value is broadcast across the whole shape; otherwise copy the
// field's values into the preallocated buffer
if(data.size() == 1)
{
std::fill(data_vals.begin(), data_vals.end(), data[0]);
}
else
copy(data.begin(), data.end(), data_vals.begin());
return data_vals;
}
static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
{
std::vector<size_t> dims;
auto input_dims = s.dim();
std::transform(input_dims.begin(),
input_dims.end(),
std::back_inserter(dims),
[](const tensorflow::TensorShapeProto_Dim& dim) { return dim.size(); });
return dims;
}
template <class T>
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
{
// if the protobuf gives an explicit value and there is at most one dimension (of size 1), treat the literal as a scalar
if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
};
program parse_tf(const std::string& name, const tf_options& options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
tf::tf_parser parser;
parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size;
parser.map_input_dims = options.map_input_dims;
......
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <utility>
#include <vector>
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/tf/op_parser.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
bool tf_parser::should_transpose(instruction_ref ins) const
{
return is_nhwc and ins->get_shape().lens().size() == 4;
}
instruction_ref tf_parser::to_nhwc(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
return ins;
}
instruction_ref tf_parser::to_nchw(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
return ins;
}
instruction_ref tf_parser::to_kcxy(instruction_ref ins) const
{
return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
}
std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const
{
std::vector<instruction_ref> result(args.size());
std::transform(
args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
return result;
}
std::vector<instruction_ref> tf_parser::to_nhwc(const std::vector<instruction_ref>& args) const
{
std::vector<instruction_ref> result(args.size());
std::transform(
args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); });
return result;
}
instruction_ref tf_parser::node_info::make_contiguous(instruction_ref ins) const
{
if(ins->get_shape().standard())
return ins;
else
return mm->add_instruction(make_op("contiguous"), ins);
}
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if(s0.size() > s1.size())
{
s0.swap(s1);
}
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
}
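// Quick check of the examples above (illustrative only):
//   compute_broadcasted_lens({3, 2, 4, 5}, {2, 1, 1});  // -> {3, 2, 4, 5}
//   compute_broadcasted_lens({3, 2, 1, 5}, {2, 7, 5});  // -> {3, 2, 7, 5}
//   compute_broadcasted_lens({2, 3}, {2, 4});           // throws: 3 vs 4 mismatch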
instruction_ref tf_parser::node_info::add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const
{
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
// Get lengths for both arguments
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens)
l0 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg0);
auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens)
l1 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg1);
return add_instruction(make_op(op_name), l0, l1);
}
else
{
return add_instruction(make_op(op_name), {arg0, arg1});
}
}
int64_t tf_parser::parse_axis(const int64_t dim, const size_t num_dims) const
{
int64_t new_dim = dim;
if(is_nhwc and num_dims >= 4)
{
switch(dim)
{
case 0: new_dim = 0; break;
case 1: new_dim = 2; break;
case 2: new_dim = 3; break;
case 3: new_dim = 1; break;
default: break;
}
}
return new_dim;
}
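// NHWC -> NCHW axis remap illustrated (assumes is_nhwc == true, num_dims == 4):
//   parse_axis(1, 4) == 2  // H
//   parse_axis(2, 4) == 3  // W
//   parse_axis(3, 4) == 1  // C
// Tensors with fewer than 4 dims pass their axes through unchanged.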
instruction_ref
tf_parser::node_info::add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const
{
return mm->add_instruction(op, args);
}
instruction_ref tf_parser::node_info::add_literal(literal l) const
{
return mm->add_literal(std::move(l));
}
std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
{
uint32_t bitwise_compare = 1;
std::vector<int64_t> axes;
for(size_t i = 0; i < num_axes; i++)
{
// bit i of the mask corresponds to axis i (the LSB is axis 0); a set bit marks that axis
if(((mask >> i) & bitwise_compare) == 1)
axes.push_back(1);
else
axes.push_back(0);
}
return axes;
}
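// Example with hypothetical values: num_axes == 4 and mask == 0b0101 give
// axes == {1, 0, 1, 0}; each set bit selects the corresponding axis.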
tf_parser::tf_parser()
{
// Add all registered op parsers
for(auto&& name : get_op_parsers())
ops.emplace(name, get_op_parser(name));
}
static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); }
static tf_parser::node_map get_nodes(const tensorflow::GraphDef& graph,
std::vector<tensorflow::NodeDef>& input_nodes)
{
tf_parser::node_map result;
for(auto&& node : graph.node())
{
auto node_name = get_name(node);
// assume each node in graph has an associated name
if(node_name.empty())
MIGRAPHX_THROW("tf node with no name found");
result[node_name] = node;
if(node.op() == "Placeholder")
{
input_nodes.push_back(node);
}
}
return result;
}
static tf_parser::attribute_map get_attributes(const tensorflow::NodeDef& node)
{
tf_parser::attribute_map result;
for(auto&& attr : node.attr())
{
result[attr.first] = attr.second;
}
return result;
}
static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
{
std::vector<size_t> dims;
auto input_dims = s.dim();
std::transform(input_dims.begin(),
input_dims.end(),
std::back_inserter(dims),
[](const tensorflow::TensorShapeProto_Dim& dim) { return dim.size(); });
return dims;
}
template <class T>
static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
const size_t& shape_size)
{
std::vector<T> data_vals(shape_size);
// a single value is broadcast across the whole shape; otherwise copy the
// field's values into the preallocated buffer
if(data.size() == 1)
{
std::fill(data_vals.begin(), data_vals.end(), data[0]);
}
else
copy(data.begin(), data.end(), data_vals.begin());
return data_vals;
}
template <class T>
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
{
// if the protobuf gives an explicit value and there is at most one dimension (of size 1), treat the literal as a scalar
if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
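// Scalar rule illustrated (values are hypothetical):
//   create_literal(shape::int32_type, {}, std::vector<int>{7});   // scalar
//   create_literal(shape::int32_type, {1}, std::vector<int>{7});  // also scalar
//   create_literal(shape::int32_type, {2, 2}, std::vector<int>{1, 2, 3, 4}); // 2x2 tensor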
void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
for(auto&& input : input_nodes)
{
const std::string& name = input.name();
attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
if(contains(map_input_dims, name))
{
dims = map_input_dims.at(name);
}
else
{
if(is_nhwc and dims.size() >= 4)
{
this->reorder_data(dims);
}
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
return static_cast<int>(dim) <= 0 ? batch_size : dim;
});
}
shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(mm->add_parameter(name, s));
}
for(auto&& p : nodes)
{
this->parse_node(p.first);
}
// TODO: add a return instruction at the end of the program
}
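// Walkthrough under assumed shapes: with is_nhwc == true, a Placeholder of
// shape (1, 224, 224, 3) becomes an NCHW parameter (1, 3, 224, 224), and
// to_nhwc() adds a transpose back so TF-ordered consumers still see NHWC;
// any dim <= 0 is replaced by batch_size.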
void tf_parser::parse_node(const std::string& name)
{
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
// Assert ops are ignored
if(node.op() == "Assert" or contains(name, "Assert"))
return;
// NoOp nodes are ignored
if(node.op() == "NoOp" or contains(name, "NoOp"))
return;
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
// control dependencies (signified by ^ before the name) are ignored
if(contains(input, "^"))
continue;
if(nodes.count(input) > 0)
{
std::string iname;
// input was from a node with multiple outputs
if(contains(input, ':'))
{
iname = input.substr(0, input.find(':'));
}
else
{
iname = get_name(nodes.at(input));
}
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(input));
}
else
{
args.push_back(instructions.at(input));
}
}
std::vector<instruction_ref> result;
if(ops.count(node.op()) == 0)
{
result.push_back(mm->add_instruction(op::unknown{node.op()}, args));
}
else
{
result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args);
}
assert(!result.empty());
// First output has no ":" delimiter
instructions[name] = result.front();
for(size_t i = 1; i < result.size(); i++)
{
instructions[name + ":" + std::to_string(i)] = result.at(i);
}
}
}
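// Output-naming convention illustrated (node name "split1" is hypothetical):
// a node with three outputs registers instructions["split1"],
// instructions["split1:1"], and instructions["split1:2"]; an input written as
// "split1:1" resolves directly against these keys.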
void tf_parser::parse_from(std::istream& is)
{
tensorflow::GraphDef graph;
if(graph.ParseFromIstream(&is))
{
this->parse_graph(graph);
}
else
{
throw std::runtime_error("Failed reading tf file");
}
}
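// Minimal usage sketch (file name is hypothetical):
//   migraphx::tf_options options;
//   options.is_nhwc    = true;
//   options.batch_size = 1;
//   auto prog = migraphx::parse_tf("frozen_graph.pb", options);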
shape::type_t tf_parser::parse_type(const tensorflow::DataType t) const
{
shape::type_t shape_type{};
switch(t)
{
case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_BOOL:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
// tf pb should not use these types
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
}
return shape_type;
}
literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const
{
std::vector<size_t> dims = parse_dims(t.tensor_shape());
size_t shape_size = std::accumulate(dims.begin(), dims.end(), size_t{1}, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data
{
const std::string& s = t.tensor_content();
switch(t.dtype())
{
case tensorflow::DataType::DT_FLOAT: return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_BOOL:
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
case tensorflow::DataType::DT_INT16: return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64: return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE: return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("unsupported tensorflow dtype");
}
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.dtype())
{
case tensorflow::DataType::DT_FLOAT:
return create_literal(shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_INT8:
return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16:
return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT16:
return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT32:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT64:
return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF:
{
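// TF widens fp16 values into the int32 half_val field; keep the low 16 bits
// of each entry and reinterpret them as an IEEE half-precision value.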
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half);
}
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("unsupported tensorflow dtype");
}
MIGRAPHX_THROW("Invalid tensor type");
}
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -338,6 +338,15 @@ def pack_test_nhwc(g1):
tf.stack([g1_input, g2_input, g3_input], axis=3, name='pack1')
@tf_test
def pad_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float32, shape=(2, 4), name='0')
paddings = tf.constant([[1, 1], [2, 2]])
tf.pad(g1_input, paddings, name='pad1')
@tf_test
def pooling_test(g1):
with g1.as_default():
@@ -579,10 +588,11 @@ if __name__ == '__main__':
mean_test()
mean_test_nhwc()
mul_test()
noop_test()
onehot_test()
pack_test()
pack_test_nhwc()
pad_test()
pooling_test()
pow_test()
relu_test()
......
@@ -71,10 +71,8 @@ TEST_CASE(add_bcast_test)
auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("add_bcast_test.pb", false);
EXPECT(p == prog);
@@ -546,6 +544,23 @@ TEST_CASE(pack_test_nhwc)
EXPECT(p == prog);
}
TEST_CASE(pad_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
std::vector<int> pad_literals{1, 1, 2, 2};
std::vector<int> pads{1, 2, 1, 2};
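// TF paddings [[1, 1], [2, 2]] flatten to (before_0, after_0, before_1, after_1);
// the parser reorders them to all-befores-then-all-afters, giving {1, 2, 1, 2}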
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, pad_literals);
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), l0);
auto prog = optimize_tf("pad_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pooling_test)
{
migraphx::program p;
......