Unverified Commit d237415f authored by kahmed10, committed by GitHub

Refactor tf parser (#726)



* refactor files

* formatting

* fix add_bcast test

* fix some tidy errors

* add transpose field and fix more tidy

* formatting

* add pad test and fix more tidy

* formatting

* fix conv parser

* fix depthwiseconv

* remove unused functions

* remove includes and functions
Co-authored-by: Paul Fultz II <pfultz2@yahoo.com>
parent 3d24a21c
@@ -19,7 +19,9 @@ target_compile_options(tf-proto PRIVATE -w)
 target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
 set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
-add_library(migraphx_tf tf.cpp)
+file(GLOB TF_SRCS *.cpp)
+add_library(migraphx_tf ${TF_SRCS})
+target_include_directories(migraphx_tf PRIVATE include)
 set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
 rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION})
 rocm_clang_tidy_check(migraphx_tf)
......
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_TF_REGISTER_OP_PARSER_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_TF_REGISTER_OP_PARSER_HPP
#include <migraphx/config.hpp>
#include <migraphx/auto_register.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct op_desc
{
std::string tf_name = "";
std::string op_name = "";
};
void register_op_parser(const std::string& name, tf_parser::op_func f);
tf_parser::op_func get_op_parser(const std::string& name);
std::vector<std::string> get_op_parsers();
inline std::vector<instruction_ref> implicit_multi_op(std::vector<instruction_ref> inss)
{
return inss;
}
inline std::vector<instruction_ref> implicit_multi_op(instruction_ref ins) { return {ins}; }
template <class T>
void register_op_parser()
{
T parser;
for(auto&& opd : parser.operators())
register_op_parser(opd.tf_name,
[opd, parser](auto&&... xs) { return parser.base_parse(opd, xs...); });
}
struct register_op_parser_action
{
template <class T>
static void apply()
{
register_op_parser<T>();
}
};
template <class Derived>
struct op_parser : auto_register<register_op_parser_action, Derived>
{
bool transpose() const { return false; }
std::vector<instruction_ref> base_parse(const op_desc& opd,
const tf_parser& parser,
tf_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
std::vector<instruction_ref> result;
auto& self = static_cast<const Derived&>(*this);
if(self.transpose())
{
result = implicit_multi_op(self.parse(opd, parser, info, parser.to_nchw(args)));
std::transform(result.begin(), result.end(), result.begin(), [&](auto ins) {
return parser.to_nhwc(ins);
});
}
else
{
result = implicit_multi_op(self.parse(opd, parser, info, args));
}
return result;
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
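A minimal sketch of how an operator plugs into the registration mechanism above: deriving from op_parser<Derived> auto-registers the parser for every tf_name listed in operators(). The Neg-to-neg mapping here is hypothetical; the parse_* structs further down follow the same pattern.
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
// Hypothetical example parser: maps the TF "Neg" node to the MIGraphX "neg" operator.
struct parse_neg_example : op_parser<parse_neg_example>
{
    std::vector<op_desc> operators() const { return {{"Neg", "neg"}}; }
    instruction_ref parse(const op_desc& opd,
                          const tf_parser& /*parser*/,
                          const tf_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        // base_parse in op_parser handles any NHWC/NCHW transposes; here we just emit the op.
        return info.add_instruction(make_op(opd.op_name), args);
    }
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx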
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_TF_PARSER_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_TF_PARSER_HPP
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <unordered_map>
#include <functional>
#include <utility>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
// namespace tf = tf_for_migraphx;
struct tf_parser
{
std::string filename;
std::string path = ".";
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
struct node_info
{
attribute_map attributes{};
std::string name = "";
module* mm = nullptr;
instruction_ref make_contiguous(instruction_ref ins) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const;
instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const;
template <class... Ts>
instruction_ref add_instruction(const operation& op, Ts... xs) const
{
return add_instruction(op, {xs...});
}
instruction_ref add_literal(literal l) const;
template <class... Ts>
instruction_ref add_literal(Ts&&... xs) const
{
return add_literal(literal{std::forward<Ts>(xs)...});
}
};
using node_map = std::map<std::string, tensorflow::NodeDef>;
using op_func = std::function<std::vector<instruction_ref>(
const tf_parser&, const node_info&, std::vector<instruction_ref>)>;
node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
module* mm = prog.get_main_module();
bool is_nhwc = true;
unsigned int batch_size = 1;
std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, op_func> ops;
tf_parser();
operation load(const std::string& name, const node_info& info) const;
bool should_transpose(instruction_ref ins) const;
instruction_ref to_nhwc(instruction_ref ins) const;
instruction_ref to_nchw(instruction_ref ins) const;
instruction_ref to_kcxy(instruction_ref ins) const;
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args) const;
std::vector<instruction_ref> to_nhwc(const std::vector<instruction_ref>& args) const;
int64_t parse_axis(int64_t dim, size_t num_dims) const;
// tf stores certain attributes, such as strides and dilations, as a 4D input.
// The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
// This helper reorders the data so it can be stored in the corresponding operator member variables.
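// e.g. an NHWC strides attribute of {1, 2, 2, 1} is reordered to {1, 1, 2, 2}, so the spatial values land in dims 2 and 3.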
template <class T>
void reorder_data(std::vector<T>& prev_data) const
{
std::vector<T> new_data(prev_data.size());
for(size_t i = 0; i < new_data.size(); i++)
{
auto new_idx = parse_axis(i, new_data.size());
new_data.at(new_idx) = prev_data.at(i);
}
prev_data = new_data;
}
void parse_undefined(module* mm, const std::string& name);
void parse_from(std::istream& is);
void parse_from(const void* data, std::size_t size);
void parse_graph(const tensorflow::GraphDef& graph);
void parse_node(const std::string& name);
literal parse_tensor(const tensorflow::TensorProto& t) const;
shape::type_t parse_type(tensorflow::DataType t) const;
};
std::vector<int64_t> get_axes_from_mask(size_t num_axes, uint32_t mask);
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/tf/op_parser.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
std::unordered_map<std::string, tf_parser::op_func>& op_parser_map()
{
static std::unordered_map<std::string, tf_parser::op_func> m; // NOLINT
return m;
}
void register_op_parser(const std::string& name, tf_parser::op_func f)
{
op_parser_map()[name] = std::move(f);
}
tf_parser::op_func get_op_parser(const std::string& name) { return op_parser_map().at(name); }
std::vector<std::string> get_op_parsers()
{
std::vector<std::string> result;
std::transform(op_parser_map().begin(),
op_parser_map().end(),
std::back_inserter(result),
[&](auto&& p) { return p.first; });
return result;
}
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_arg_op : op_parser<parse_arg_op>
{
std::vector<op_desc> operators() const { return {{"ArgMax", "argmax"}, {"ArgMin", "argmin"}}; }
instruction_ref parse(const op_desc& opd,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
int64_t axis = 0;
axis = args[1]->eval().at<int64_t>();
auto ins = info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args.front());
return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_batchnorm : op_parser<parse_batchnorm>
{
bool transpose() const { return true; }
std::vector<op_desc> operators() const { return {{"FusedBatchNorm"}, {"FusedBatchNormV3"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
float epsilon = 1e-5f;
float momentum = 0.9f;
if(contains(info.attributes, "epsilon"))
{
epsilon = info.attributes.at("epsilon").f();
}
auto op = make_op("batch_norm_inference", {{"epsilon", epsilon}, {"momentum", momentum}});
return info.add_instruction(op, args);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_biasadd : op_parser<parse_biasadd>
{
bool transpose() const { return true; }
std::vector<op_desc> operators() const { return {{"BiasAdd"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
return info.add_instruction(make_op("add"), args[0], l0);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_binary_op : op_parser<parse_binary_op>
{
std::vector<op_desc> operators() const
{
return {{"Add", "add"},
{"AddV2", "add"},
{"Mul", "mul"},
{"Pow", "pow"},
{"SquaredDifference", "sqdiff"},
{"Sub", "sub"}};
}
instruction_ref parse(const op_desc& opd,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
return info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_cast : op_parser<parse_cast>
{
std::vector<op_desc> operators() const { return {{"Cast"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& parser,
tf_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
shape::type_t type = parser.parse_type(info.attributes.at("DstT").type());
return info.add_instruction(make_op("convert", {{"target_type", type}}), args);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_concat : op_parser<parse_concat>
{
std::vector<op_desc> operators() const { return {{"ConcatV2"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
// get index for axis within args
size_t axis_idx = info.attributes.at("N").i();
int64_t axis = args[axis_idx]->eval().at<int64_t>();
auto op = make_op("concat", {{"axis", axis}});
// return only first N arguments (assuming last index is the axis value)
return info.add_instruction(
op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1));
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_constant_op : op_parser<parse_constant_op>
{
bool transpose() const { return true; }
std::vector<op_desc> operators() const { return {{"Const"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& parser,
tf_parser::node_info info,
const std::vector<instruction_ref>& /*args*/) const
{
literal v = parser.parse_tensor(info.attributes.at("value").tensor());
return info.add_literal(v);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_conv : op_parser<parse_conv>
{
bool transpose() const { return true; }
std::vector<op_desc> operators() const { return {{"Conv2D"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& parser,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
op::convolution op;
if(contains(info.attributes, "strides"))
{
std::vector<size_t> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
if(contains(info.attributes, "dilations"))
{
std::vector<size_t> dilation;
copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation));
parser.reorder_data(dilation);
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
auto weights = parser.to_kcxy(args[1]);
auto l0 = args[0];
if(contains(info.attributes, "padding"))
{
const std::string& pad_mode = info.attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens();
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<size_t> padding;
copy(info.attributes.at("explicit_paddings").list().i(),
std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
return info.add_instruction(op, {l0, weights});
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
{
bool transpose() const { return true; }
std::vector<op_desc> operators() const { return {{"DepthwiseConv2dNative"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& parser,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1];
op.group = num_channels;
if(contains(info.attributes, "strides"))
{
std::vector<size_t> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
auto weights = parser.to_kcxy(args[1]);
if(contains(info.attributes, "dilations"))
{
std::vector<size_t> dilation;
copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation));
parser.reorder_data(dilation);
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
auto l0 = args[0];
if(contains(info.attributes, "padding"))
{
const std::string& pad_mode = info.attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens();
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
std::vector<int64_t> new_weights_shape;
copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
// weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
// out_channels is equal to the multiplier. Adjust by inserting a reshape and
// setting in_channels to 1
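// e.g. KCXY weights of shape (multiplier, in_channels, h, w) become (in_channels * multiplier, 1, h, w)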
int64_t multiplier = new_weights_shape[0];
int64_t out_channels = num_channels * multiplier;
new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape
auto new_weights = info.add_instruction(make_op("reshape", {{"dims", new_weights_shape}}),
info.make_contiguous(weights));
return info.add_instruction(op, {l0, new_weights});
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_expanddims : op_parser<parse_expanddims>
{
std::vector<op_desc> operators() const { return {{"ExpandDims"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
std::vector<size_t> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
size_t num_dims = input_dims.size();
int32_t dim = args[1]->eval().at<int32_t>();
if(dim < 0)
{
new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
}
else
{
new_dims.insert(new_dims.begin() + dim, 1);
}
return info.add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_gather : op_parser<parse_gather>
{
std::vector<op_desc> operators() const { return {{"GatherV2"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int axis = args[2]->eval().at<int32_t>();
return info.add_instruction(make_op("gather", {{"axis", axis}}), {args[0], args[1]});
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_generic_op : op_parser<parse_generic_op>
{
std::vector<op_desc> operators() const
{
return {{"All", "identity"},
{"Identity", "identity"},
{"LessEqual", "identity"},
{"Relu", "relu"},
{"Rsqrt", "rsqrt"},
{"Tanh", "tanh"},
{"StopGradient", "identity"}};
}
instruction_ref parse(const op_desc& opd,
const tf_parser& /*parser*/,
const tf_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
return info.add_instruction(make_op(opd.op_name), args);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_matmul : op_parser<parse_matmul>
{
std::vector<op_desc> operators() const
{
return {{"BatchMatMul"}, {"BatchMatMulV2"}, {"MatMul"}};
}
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
bool transa = false;
bool transb = false;
if(contains(info.attributes, "transpose_a"))
{
transa = info.attributes.at("transpose_a").b();
}
if(contains(info.attributes, "transpose_b"))
{
transb = info.attributes.at("transpose_b").b();
}
if(contains(info.attributes, "adj_x"))
{
transa = info.attributes.at("adj_x").b();
}
if(contains(info.attributes, "adj_y"))
{
transb = info.attributes.at("adj_y").b();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
return info.add_instruction(make_op("dot"), l1, l2);
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_mean : op_parser<parse_mean>
{
std::vector<op_desc> operators() const { return {{"Mean"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
bool keep_dims = info.attributes.at("keep_dims").b();
auto axes = args[1]->eval().get<int32_t>().to_vector<int64_t>();
auto ins = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
if(not keep_dims)
ins = info.add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
return ins;
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_onehot : op_parser<parse_onehot>
{
std::vector<op_desc> operators() const { return {{"OneHot"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& /*parser*/,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
int64_t axis = -1;
float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>();
std::vector<float> depth_input(depth * depth, off_value);
for(int i = 0; i < depth; i++)
{
depth_input[depth * i + i] = on_value;
}
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
if(axis == -1)
{
shape s{shape::float_type, {depth, depth}};
auto l0 = info.add_literal({s, depth_input});
return info.add_instruction(make_op("gather", {{"axis", 0}}), {l0, args[0]});
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {
struct parse_pack : op_parser<parse_pack>
{
std::vector<op_desc> operators() const { return {{"Pack"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const tf_parser& parser,
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
// reinterpret as unsqueeze with concat
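// e.g. packing three [2, 3] tensors along axis 0 unsqueezes each to [1, 2, 3] and concatenates them into [3, 2, 3]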
std::vector<instruction_ref> unsqueezed_args;
int64_t axis = 0;
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
size_t input_size = args.front()->get_shape().lens().size();
if(axis > input_size)
{
MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
" must be smaller than input size " + to_string(input_size));
}
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) {
return info.add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
return parser.to_nhwc(
info.add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
}
};
} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx