Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Fill `indices` with reflect-pad slice positions that bounce back and
// forth inside [0, num_dims - 1], e.g. for num_dims == 4 the sequence is
// 1, 2, 3, 2, 1, 0, 1, ... so pad counts larger than the dimension
// length repeat periodically.
void calc_reflect_indices(std::vector<int>& indices, const int64_t num_dims)
{
    int pos       = 0;
    bool backward = false;
    for(int& out : indices)
    {
        // flip the walking direction at either edge of the dimension
        if(pos == num_dims - 1)
            backward = true;
        if(pos == 0)
            backward = false;
        pos = backward ? pos - 1 : pos + 1;
        out = pos;
    }
}
// Implements ONNX Pad with mode="reflect" by slicing single-element
// rows/columns out of the input and concatenating them around it, one
// axis at a time (innermost axis first).
// pads holds kdims leading pad counts followed by kdims trailing counts.
// NOTE(review): assumes pads.size() == 2 * input rank - confirm at caller.
instruction_ref reflect_pad(const onnx_parser::node_info& info,
const std::vector<int64_t>& pads,
instruction_ref input)
{
// first half of pads = leading (left) pads, second half = trailing (right)
size_t num_dims = pads.size() / 2;
std::vector<int> ldims(pads.begin(), pads.begin() + num_dims);
std::vector<int> rdims(pads.begin() + num_dims, pads.end());
assert(ldims.size() == rdims.size());
std::vector<int64_t> axes(num_dims);
std::iota(axes.begin(), axes.end(), int64_t{0});
// iterate over dimensions, starting from lowest dimension
for(int64_t i = num_dims - 1; i >= 0; i--)
{
auto axis = i;
auto lcount = ldims.at(i);
auto rcount = rdims.at(i);
if(lcount == 0 and rcount == 0) // no padding for current dim
continue;
// calculate starts and ends for each iteration since shape may change
std::vector<size_t> dims = input->get_shape().lens();
std::vector<int64_t> starts(axes.size(), 0);
std::vector<int64_t> ends(dims.begin(), dims.end());
std::vector<instruction_ref> slices;
// iterators pointing at the slice window entry for the current axis
auto starts_it = starts.begin() + i;
auto ends_it = ends.begin() + i;
auto dims_it = dims.begin() + i;
std::vector<int> l_indices(lcount);
std::vector<int> r_indices(rcount);
// compute slice indices in a periodic fashion
calc_reflect_indices(l_indices, *dims_it);
calc_reflect_indices(r_indices, *dims_it);
// left pads: take element `idx` counted from the front of the axis
for(int idx : l_indices)
{
*starts_it = idx;
*ends_it = *starts_it + 1;
slices.push_back(info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
}
// when padding on the left side, the outermost pad should be at the beginning
std::reverse(slices.begin(), slices.end());
slices.push_back(input);
// right pads: take element `idx` counted back from the end of the axis
for(int idx : r_indices)
{
*starts_it = *dims_it - idx - 1;
*ends_it = *starts_it + 1;
slices.push_back(info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
}
// concatenate along this axis; the result feeds the next (outer) axis
input = info.add_instruction(make_op("concat", {{"axis", axis}}), slices);
}
return input;
}
// Parses ONNX Pad. Pad amounts come from input 2 (opset 11+) or the
// "pads" attribute (older opsets); mode may be "constant" (default) or
// "reflect". Constant padding maps onto the migraphx pad op; reflect
// padding is lowered via reflect_pad.
struct parse_pad : op_parser<parse_pad>
{
    std::vector<op_desc> operators() const { return {{"Pad"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        std::vector<int64_t> pads{};
        if(args.size() >= 2)
        {
            auto pad_arg = args.at(1)->eval();
            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        else
        {
            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
        }
        // check if padding is actually being done (at least one value is nonzero)
        // BUGFIX: the predicate takes int64_t; it previously took const int&,
        // which narrowed the int64_t pad values and could truncate large ones
        if(std::all_of(pads.begin(), pads.end(), [](int64_t i) { return i == 0; }))
        {
            return info.add_instruction(make_op("identity"), args.front());
        }
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode == "reflect")
                return reflect_pad(info, pads, args.front());
            if(mode != "constant")
            {
                MIGRAPHX_THROW(
                    "PARSE_PAD: migraphx currently only supports constant and reflect padding");
            }
        }
        // constant fill value: input 3 (opset 11+) or the "value" attribute
        float value = 0.0f;
        // third input is the value
        if(args.size() == 3)
        {
            auto val_ins = args.at(2);
            if(!val_ins->can_eval())
            {
                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
            }
            auto val_arg = val_ins->eval();
            if(val_arg.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
            }
            value = val_arg.at<float>();
        }
        else if(contains(info.attributes, "value"))
        {
            value = parser.parse_value(info.attributes.at("value")).at<float>();
        }
        return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}),
                                    args.front());
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses AveragePool/MaxPool/LpPool and their Global* variants into the
// migraphx pooling op, handling kernel/stride/pad attributes, auto_pad,
// ceil_mode, count_include_pad, and asymmetric padding (which is lowered
// to an explicit pad op plus an output slice when required).
struct parse_pooling : op_parser<parse_pooling>
{
std::vector<op_desc> operators() const
{
return {{"AveragePool", "average"},
{"GlobalAveragePool", "average"},
{"GlobalMaxPool", "max"},
{"MaxPool", "max"},
{"LpPool", "lpnorm"},
{"GlobalLpPool", "lpnorm"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
// map the op_desc operator name onto the pooling mode enum
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
std::string mode = opd.op_name;
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
// number of spatial dimensions; input layout is [N, C, spatial...]
auto kdims = in_lens.size() - 2;
// Global* pooling uses the entire spatial extent as the kernel
if(starts_with(opd.onnx_name, "Global"))
{
values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
}
// forward ceil_mode (output-size rounding) to the pooling op
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
count_include_pad = info.attributes.at("count_include_pad").i();
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
std::vector<int64_t> paddings;
// padding fill value: -inf for max pooling so pads never win, else 0
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(contains(info.attributes, "auto_pad"))
{
values["padding"].clear();
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
{1, 1},
in_lens,
paddings);
}
// default any still-unset padding/stride entries (no pad, unit stride)
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
// tune_padding_size may symmetrize the padding; a non-empty slice_start
// means the padded output has extra rows/cols to slice away afterwards
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(!slice_start.empty())
{
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
// insert an explicit pad op in front when padding is asymmetric
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
// trim the extra output introduced by the symmetrized padding
if(!slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
l1);
}
return l1;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Picks the common compute type for pow(): the operand whose type ranks
// higher in the promotion order below wins; ties keep the base type t1.
auto compute_type(shape::type_t t1, shape::type_t t2)
{
    const static std::unordered_map<int, int> op_order = {{shape::int8_type, 1},
                                                          {shape::uint8_type, 2},
                                                          {shape::int16_type, 3},
                                                          {shape::uint16_type, 4},
                                                          {shape::int32_type, 5},
                                                          {shape::uint32_type, 6},
                                                          {shape::int64_type, 7},
                                                          {shape::uint64_type, 8},
                                                          {shape::half_type, 9},
                                                          {shape::float_type, 10},
                                                          {shape::double_type, 11}};
    const auto key1 = static_cast<int>(t1);
    const auto key2 = static_cast<int>(t2);
    if(not contains(op_order, key1) or not contains(op_order, key2))
    {
        MIGRAPHX_THROW("PARSE_POW: Input data type not supported!");
    }
    return (op_order.at(key1) < op_order.at(key2)) ? t2 : t1;
}
// Parses ONNX Pow. The operation is computed in the "larger" of the two
// input types (see compute_type) and the result is converted back to the
// base input's type.
struct parse_pow : op_parser<parse_pow>
{
    std::vector<op_desc> operators() const { return {{"Pow"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        const auto base_t = args[0]->get_shape().type();
        const auto exp_t  = args[1]->get_shape().type();
        const auto common = compute_type(base_t, exp_t);
        // promote an operand to the common compute type
        auto to_common = [&](instruction_ref ins) {
            return info.add_instruction(make_op("convert", {{"target_type", common}}), ins);
        };
        if(common != base_t)
            args[0] = to_common(args[0]);
        if(common != exp_t)
            args[1] = to_common(args[1]);
        auto result = info.add_broadcastable_binary_op("pow", args[0], args[1]);
        // the ONNX output type follows the base input
        if(common != base_t)
            result = info.add_instruction(make_op("convert", {{"target_type", base_t}}), result);
        return result;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Shared implementation for prefix-scan style ops (currently CumSum).
// args[0] is the data; args[1] must be a constant scalar giving the axis.
instruction_ref parse_prefix_scan_oper(const std::string& op_name,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args)
{
migraphx::argument in = args[1]->eval();
check_arg_empty(in, "PARSE_PREFIX_SCAN: axis - dynamic shape not supported");
std::vector<std::size_t> axis_in;
in.visit([&](auto input) { axis_in.assign(input.begin(), input.end()); });
// NOTE(review): a negative axis value would wrap through std::size_t
// here before being narrowed back to int64_t - confirm negative axes
// are normalized downstream.
int64_t axis = axis_in[0];
// ONNX defaults: inclusive scan, forward direction
bool exclusive = false;
bool reverse = false;
if(contains(info.attributes, "exclusive"))
{
exclusive = parser.parse_value(info.attributes.at("exclusive")).at<bool>();
}
if(contains(info.attributes, "reverse"))
{
reverse = parser.parse_value(info.attributes.at("reverse")).at<bool>();
}
return info.add_instruction(
make_op(op_name, {{"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}}),
args[0]);
}
// Registers CumSum and forwards to parse_prefix_scan_oper with the
// mapped migraphx operator name.
struct parse_prefix_scan_op : op_parser<parse_prefix_scan_op>
{
std::vector<op_desc> operators() const { return {{"CumSum", "prefix_scan_sum"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
return parse_prefix_scan_oper(opd.op_name, parser, std::move(info), std::move(args));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses ONNX QuantizeLinear: y = saturate(round(x / y_scale) + y_zero_point).
// y_scale (args[1]) and the optional y_zero_point (args[2]) are either
// scalars (multibroadcast) or per-axis vectors (broadcast along `axis`).
struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
    std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; }

    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        int axis = 1;
        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();

        auto input_lens = args[0]->get_shape().lens();
        auto n_dim      = input_lens.size();

        // Broadcast a quantization parameter up to the input shape. The
        // scalar-vs-per-axis handling is identical for y_scale and
        // y_zero_point, so it is shared here instead of duplicated.
        // tune_axis is only evaluated on the per-axis path, as before.
        auto broadcast_param = [&](instruction_ref ins) {
            if(ins->get_shape().elements() != 1)
            {
                auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
                return info.add_instruction(
                    make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), ins);
            }
            return info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
                                        ins);
        };

        auto y_scale = broadcast_param(args[1]);
        if(args.size() == 3)
        {
            auto y_zero_point = broadcast_param(args[2]);
            return info.add_instruction(
                make_op("quantizelinear"), args[0], y_scale, y_zero_point);
        }
        return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <random>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses RandomNormal / RandomNormalLike: materializes a literal of
// normally distributed values (mean/scale attributes) at parse time.
// Shape comes from the "shape" attribute (RandomNormal) or from the
// single input (RandomNormalLike); dtype must be float, half, or double.
struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
{
const std::set<shape::type_t> valid_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
// dtype defaults to 1 (float); remember whether it was set explicitly
int dtype = 1;
bool use_dtype = false;
if(contains(info.attributes, "dtype"))
{
dtype = info.attributes.at("dtype").i();
use_dtype = true;
}
shape::type_t out_type = get_type(dtype);
if(not contains(valid_types, out_type))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double).");
float mean = 0.0;
if(contains(info.attributes, "mean"))
mean = info.attributes.at("mean").f();
float scale = 1.0;
if(contains(info.attributes, "scale"))
scale = info.attributes.at("scale").f();
shape out_shape;
if(contains(info.attributes, "shape"))
{
// RandomNormal:
// output type and shape must come from attributes
std::vector<int> out_lens;
literal ls = parser.parse_value(info.attributes.at("shape"));
ls.visit([&](auto s) { out_lens.assign(s.begin(), s.end()); });
out_shape = shape{out_type, out_lens};
}
else if(args.size() == 1)
{
// RandomNormalLike:
// output type and shape are the same as the input's by default
// dtype is used instead when attribute is set
if(not contains(valid_types, args[0]->get_shape().type()))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " +
std::to_string(args[0]->get_shape().type()) +
". Valid types are float, half, and double.");
out_shape =
use_dtype ? shape{out_type, args[0]->get_shape().lens()} : args[0]->get_shape();
}
else
{
MIGRAPHX_THROW(opd.op_name +
": cannot deduce shape without shape attribute or argument.");
}
// seed from the clock unless a "seed" attribute is given
// NOTE(review): uses std::chrono without a visible #include <chrono>
// in this file's include block - presumably pulled in transitively;
// confirm.
std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
// values are generated once at parse time into a literal, not a runtime op
std::normal_distribution<> d(mean, scale);
std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
return info.add_literal(literal{out_shape, rand_vals});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <random>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses RandomUniform / RandomUniformLike: materializes a literal of
// uniformly distributed values in [low, high) at parse time. Shape comes
// from the "shape" attribute (RandomUniform) or from the single input
// (RandomUniformLike); dtype must be float, half, or double.
struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
{
    const std::set<shape::type_t> valid_types = {
        shape::float_type, shape::half_type, shape::double_type};

    std::vector<op_desc> operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; }

    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // dtype defaults to 1 (float); remember whether it was set explicitly
        int dtype      = 1;
        bool use_dtype = false;
        if(contains(info.attributes, "dtype"))
        {
            dtype     = info.attributes.at("dtype").i();
            use_dtype = true;
        }
        shape::type_t out_type = get_type(dtype);
        if(not contains(valid_types, out_type))
            MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
                           ". Valid types are 1 (float), 10 (half), and 11 (double).");
        // ONNX defaults: values drawn from [0.0, 1.0)
        float high = 1.0;
        if(contains(info.attributes, "high"))
            high = info.attributes.at("high").f();
        float low = 0.0;
        if(contains(info.attributes, "low"))
            low = info.attributes.at("low").f();
        shape out_shape;
        if(contains(info.attributes, "shape"))
        {
            // RandomUniform:
            // output type and shape must come from attributes
            std::vector<int> out_lens;
            literal ls = parser.parse_value(info.attributes.at("shape"));
            ls.visit([&](auto s) { out_lens.assign(s.begin(), s.end()); });
            out_shape = shape{out_type, out_lens};
        }
        else if(args.size() == 1)
        {
            // RandomUniformLike:
            // output type and shape are the same as the input by default
            // dtype is used instead when attribute is set
            if(not contains(valid_types, args[0]->get_shape().type()))
                MIGRAPHX_THROW(opd.op_name + ": invalid output type: " +
                               std::to_string(args[0]->get_shape().type()) +
                               ". Valid types are float, half, and double.");
            out_shape =
                use_dtype ? shape{out_type, args[0]->get_shape().lens()} : args[0]->get_shape();
        }
        else
        {
            MIGRAPHX_THROW(opd.op_name +
                           ": cannot deduce shape without shape attribute or argument.");
        }
        // seed from the clock unless a "seed" attribute is given
        // NOTE(review): uses std::chrono without a visible #include <chrono> - confirm
        std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
        if(contains(info.attributes, "seed"))
            gen.seed(info.attributes.at("seed").f());
        // BUGFIX: uniform_real_distribution's constructor takes (min, max);
        // the arguments were previously passed as (high, low), which with the
        // defaults constructs the distribution with a=1, b=0 and violates its
        // a <= b precondition (undefined behavior).
        std::uniform_real_distribution<> d(low, high);
        std::vector<double> rand_vals(out_shape.elements());
        std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
        return info.add_literal(literal{out_shape, rand_vals});
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses ONNX Range: start, limit, and delta must all be constant
// scalars; the arithmetic sequence is materialized at parse time as a
// literal of the start argument's type.
struct parse_range : op_parser<parse_range>
{
std::vector<op_desc> operators() const { return {{"Range"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto start_arg = args[0]->eval();
check_arg_empty(start_arg, "PARSE_RANGE: start arg dynamic shape is not supported");
auto limit_arg = args[1]->eval();
check_arg_empty(limit_arg, "PARSE_RANGE: limit arg dynamic shape is not supported");
auto delta_arg = args[2]->eval();
check_arg_empty(delta_arg, "PARSE_RANGE: delta arg dynamic shape is not supported");
assert(args[0]->get_shape().elements() == 1 and args[1]->get_shape().elements() == 1 and
args[2]->get_shape().elements() == 1);
instruction_ref l0;
visit_all(start_arg, limit_arg, delta_arg)([&](auto start, auto limit, auto delta) {
auto start_val = start.front();
auto limit_val = limit.front();
auto delta_val = delta.front();
// element count per the ONNX spec: ceil((limit - start) / delta);
// the assert below only covers the strictly-positive case
size_t num_elements = static_cast<size_t>(
ceil(static_cast<double>(limit_val - start_val) / static_cast<double>(delta_val)));
assert(num_elements > 0);
using type = decltype(start_val);
// generate start, start + delta, start + 2*delta, ...
std::vector<type> range_vals(num_elements);
std::generate(range_vals.begin(), range_vals.end(), [&]() {
auto result = start_val;
start_val += delta_val;
return result;
});
l0 = info.add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals});
});
return l0;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Shared implementation for ONNX reduce operators. Axes come from a
// constant second input (newer opsets), the "axes" attribute, or default
// to all dimensions. keepdims == 0 appends a squeeze over the reduced
// axes, since migraphx reduce ops keep them as size-1 dimensions.
instruction_ref parse_reduce_oper(const std::string& op_name,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args)
{
// default to reduce over all dimensions
std::vector<int64_t> axes;
if(args.size() == 2)
{
auto arg_axes = args.at(1)->eval();
check_arg_empty(arg_axes, "PARSE_" + op_name + ": cannot handle variable axes!");
axes.clear();
arg_axes.visit([&](auto s) { axes.assign(s.begin(), s.end()); });
}
else if(contains(info.attributes, "axes"))
{
axes.clear();
auto&& attr_axes = info.attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
// noop_with_empty_axes == 1 turns "no axes given" into the identity
// instead of "reduce everything"
bool noop_with_empty_axes = false;
if(contains(info.attributes, "noop_with_empty_axes"))
{
noop_with_empty_axes = static_cast<bool>(
parser.parse_value(info.attributes.at("noop_with_empty_axes")).at<int>());
}
// empty axes behavior
if(axes.empty())
{
if(noop_with_empty_axes)
{
return args.at(0);
}
else
{
std::size_t n_dim = args.front()->get_shape().lens().size();
axes.resize(n_dim);
std::iota(axes.begin(), axes.end(), 0);
}
}
// keepdims defaults to 1 per the ONNX spec
int keep_dims = 1;
if(contains(info.attributes, "keepdims"))
{
keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return info.add_instruction(make_op(op_name, {{"axes", axes}}), args.front());
}
else
{
// squeeze away the size-1 dims the reduce op kept
auto ins = info.add_instruction(make_op(op_name, {{"axes", axes}}), args.front());
return info.add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
}
}
// Registers the reduce ops that map 1:1 onto migraphx reduce operators
// and forwards them to parse_reduce_oper.
struct parse_reduce_op : op_parser<parse_reduce_op>
{
std::vector<op_desc> operators() const
{
return {{"ReduceMax", "reduce_max"},
{"ReduceMean", "reduce_mean"},
{"ReduceMin", "reduce_min"},
{"ReduceProd", "reduce_prod"},
{"ReduceSum", "reduce_sum"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
return parse_reduce_oper(opd.op_name, parser, std::move(info), std::move(args));
}
};
// ReduceL1(x) == ReduceSum(|x|)
struct parse_reduce_l1 : op_parser<parse_reduce_l1>
{
    std::vector<op_desc> operators() const { return {{"ReduceL1"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        // reduce over the absolute values of the input
        auto absolute = info.add_instruction(make_op("abs"), args[0]);
        return parse_reduce_oper("reduce_sum", parser, std::move(info), {absolute});
    }
};
// ReduceL2(x) == sqrt(ReduceSum(x * x))
struct parse_reduce_l2 : op_parser<parse_reduce_l2>
{
    std::vector<op_desc> operators() const { return {{"ReduceL2"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto squared = info.add_instruction(make_op("mul"), args[0], args[0]);
        auto summed  = parse_reduce_oper("reduce_sum", parser, info, {squared});
        return info.add_instruction(make_op("sqrt"), summed);
    }
};
// ReduceLogSum(x) == log(ReduceSum(x))
struct parse_reduce_log_sum : op_parser<parse_reduce_log_sum>
{
    std::vector<op_desc> operators() const { return {{"ReduceLogSum"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto summed = parse_reduce_oper("reduce_sum", parser, info, std::move(args));
        return info.add_instruction(make_op("log"), summed);
    }
};
// ReduceLogSumExp(x) == log(ReduceSum(exp(x)))
struct parse_reduce_log_sum_exp : op_parser<parse_reduce_log_sum_exp>
{
    std::vector<op_desc> operators() const { return {{"ReduceLogSumExp"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto exponent = info.add_instruction(make_op("exp"), args[0]);
        auto summed   = parse_reduce_oper("reduce_sum", parser, info, {exponent});
        return info.add_instruction(make_op("log"), summed);
    }
};
// ReduceSumSquare(x) == ReduceSum(x * x)
struct parse_reduce_sum_square : op_parser<parse_reduce_sum_square>
{
    std::vector<op_desc> operators() const { return {{"ReduceSumSquare"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        auto squared = info.add_instruction(make_op("mul"), args[0], args[0]);
        return parse_reduce_oper("reduce_sum", parser, std::move(info), {squared});
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses ONNX Reshape. Target dims come from the "shape" attribute
// (opset 1) or from a constant second input (opset 5+); dynamic shape
// inputs are rejected.
struct parse_reshape : op_parser<parse_reshape>
{
    std::vector<op_desc> operators() const { return {{"Reshape"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        std::vector<int64_t> dims;
        // shared visitor: append whatever dims the literal/argument holds
        auto append_dims = [&](auto v) { copy(v, std::back_inserter(dims)); };
        if(args.size() == 1)
        {
            literal s = parser.parse_value(info.attributes.at("shape"));
            s.visit(append_dims);
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit(append_dims);
        }
        return info.add_instruction(make_op("reshape", {{"dims", dims}}),
                                    info.make_contiguous(args[0]));
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Returns the index-rounding function for an ONNX Resize "nearest_mode".
// Every variant first clamps the fractional source coordinate into
// [0, d_in - 1] and then rounds it to an integer input index.
const auto& get_nearest_op(const std::string& mode)
{
    using nearest_op = std::function<std::size_t(std::size_t, double)>;
    static const std::unordered_map<std::string, nearest_op> nearest_ops = {
        {"round_prefer_floor",
         [](std::size_t d_in, double val) {
             val = std::max(0.0, std::min(d_in - 1.0, val));
             // halfway values round down
             return static_cast<std::size_t>(std::ceil(val - 0.5));
         }},
        {"round_prefer_ceil",
         [](std::size_t d_in, double val) {
             val = std::max(0.0, std::min(d_in - 1.0, val));
             // std::round takes halfway cases away from zero (i.e. up here)
             return static_cast<std::size_t>(std::round(val));
         }},
        {"floor",
         [](std::size_t d_in, double val) {
             val = std::max(0.0, std::min(d_in - 1.0, val));
             return static_cast<std::size_t>(std::floor(val));
         }},
        {"ceil",
         [](std::size_t d_in, double val) {
             val = std::max(0.0, std::min(d_in - 1.0, val));
             return static_cast<std::size_t>(std::ceil(val));
         }}};
    const auto op_it = nearest_ops.find(mode);
    if(op_it == nearest_ops.end())
    {
        MIGRAPHX_THROW("PARSE_RESIZE: nearest_mode " + mode + " not supported!");
    }
    return op_it->second;
}
// Returns the ONNX Resize "coordinate_transformation_mode" function that
// maps an output index (given input/output lengths and the scale) back to
// the fractional input coordinate it samples from.
const auto& get_original_idx_op(const std::string& mode)
{
    using original_idx_op = std::function<double(std::size_t, std::size_t, std::size_t, double)>;
    static const std::unordered_map<std::string, original_idx_op> idx_ops = {
        {"half_pixel",
         [](std::size_t, std::size_t, std::size_t idx, double scale) {
             return (idx + 0.5) / scale - 0.5;
         }},
        {"pytorch_half_pixel",
         [](std::size_t, std::size_t l_out, std::size_t idx, double scale) {
             // single-element outputs always map to coordinate 0
             return l_out > 1 ? (idx + 0.5) / scale - 0.5 : 0.0;
         }},
        {"align_corners",
         [](std::size_t l_in, std::size_t l_out, std::size_t idx, double) {
             return (l_out == 1) ? 0.0 : (1.0 * idx * (l_in - 1.0) / (l_out - 1.0));
         }},
        {"asymmetric",
         [](std::size_t, std::size_t, std::size_t idx, double scale) { return idx / scale; }},
        {"tf_half_pixel_for_nn",
         [](std::size_t, std::size_t, std::size_t idx, double scale) {
             return (idx + 0.5) / scale;
         }}};
    const auto op_it = idx_ops.find(mode);
    if(op_it == idx_ops.end())
    {
        MIGRAPHX_THROW("PARSE_RESIZE: coordinate_transformation_mode " + mode + " not supported!");
    }
    return op_it->second;
}
// Recursively expands, dimension by dimension, the per-axis lo/hi
// neighbor indices in vvv_ind into full multi-dimensional coordinates,
// then flattens every coordinate into a linear index of in_s.
// vvv_ind[d][0] / vvv_ind[d][1] hold the low / high neighbor index of
// each output position along dimension d; vec_dims accumulates the
// partial coordinate tuples built so far.
static std::vector<int>
calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
int i_dim,
const std::vector<std::vector<std::size_t>>& vec_dims,
const shape& in_s)
{
// base case: all dimensions consumed - map coordinates to linear indices
if(i_dim == vvv_ind.size())
{
std::vector<int> vec_ind;
vec_ind.resize(vec_dims.size());
std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) {
return static_cast<int>(in_s.index(idx));
});
return vec_ind;
}
const auto& vv_ind = vvv_ind[i_dim];
const auto& vv_lo = vv_ind.at(0);
std::vector<std::vector<std::size_t>> vec_dims1;
// append this dimension's low neighbor to every partial coordinate
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
{
std::transform(vv_lo.begin(),
vv_lo.end(),
vec_dims.begin() + start,
std::back_inserter(vec_dims1),
[](auto i, auto dim) {
dim.push_back(i);
return dim;
});
}
const auto& vv_hi = vv_ind.at(1);
// then the high neighbor, doubling the number of combinations
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
{
std::transform(vv_hi.begin(),
vv_hi.end(),
vec_dims.begin() + start,
std::back_inserter(vec_dims1),
[](auto i, auto dim) {
dim.push_back(i);
return dim;
});
}
return calc_neighbor_points(vvv_ind, i_dim + 1, vec_dims1, in_s);
}
// Reads Resize's coordinate_transformation_mode attribute, defaulting to
// "half_pixel"; "tf_crop_and_resize" is explicitly rejected.
static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
{
    if(not contains(attr, "coordinate_transformation_mode"))
        return "half_pixel";
    std::string coord_trans_mode = attr.at("coordinate_transformation_mode").s();
    // does not support transformation mode "tf_crop_and_resize"
    if(coord_trans_mode == "tf_crop_and_resize")
    {
        MIGRAPHX_THROW("PARSE_RESIZE: \"tf_crop_and_resize\" mode is not supported!");
    }
    return coord_trans_mode;
}
// Reads Resize's interpolation mode, defaulting to "nearest"; anything
// other than "nearest" or "linear" is rejected.
static std::string get_mode(const onnx_parser::attribute_map& attr)
{
    if(not contains(attr, "mode"))
        return "nearest";
    std::string mode = attr.at("mode").s();
    if(mode != "nearest" and mode != "linear")
    {
        MIGRAPHX_THROW("PARSE_RESIZE: only nearest and linear modes are supported!");
    }
    return mode;
}
// Reads Resize's nearest_mode attribute; "round_prefer_floor" is the
// ONNX default when the attribute is absent.
static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
{
    if(contains(attr, "nearest_mode"))
        return attr.at("nearest_mode").s();
    return "round_prefer_floor";
}
// Parser for the ONNX Resize and Upsample operators.
// The resize is lowered to data-movement ops: the input is flattened to 1-D
// and resampled with a precomputed gather (nearest mode) or a chain of
// gather + linear-interpolation arithmetic (linear mode). All indices are
// computed at parse time, so sizes/scales must be compile-time constants.
struct parse_resize : op_parser<parse_resize>
{
    std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }

    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        // coord transform mode (throws on the unsupported "tf_crop_and_resize")
        std::string coord_trans_mode = get_coord_trans_mode(info.attributes);
        // mode: only nearest and linear modes are supported for now
        std::string mode = get_mode(info.attributes);
        // nearest rounding rule; only consulted when mode == "nearest"
        std::string nearest_mode = get_nearest_mode(info.attributes);

        // check exclude_outside, only support 0
        if(contains(info.attributes, "exclude_outside") and
           info.attributes.at("exclude_outside").i() == 1)
        {
            MIGRAPHX_THROW("PARSE_" + opd.op_name + ": exclude_outside 1 is not supported!");
        }

        // input data shape info
        auto in_s    = args[0]->get_shape();
        auto in_lens = in_s.lens();

        // output shape is explicitly specified
        std::vector<std::size_t> out_lens(in_lens.size());

        // per-dimension scale factors (out = in * scale)
        std::vector<double> vec_scale;

        // Scan the optional inputs: an int64 constant is the explicit output
        // size; a non-int64 constant whose length equals the input rank is
        // the scales vector. Either way the other quantity is derived.
        for(const auto& arg : args)
        {
            if(arg->name() == "undefined" or arg == args.front())
            {
                continue;
            }

            // skipped empty input
            auto lens = arg->get_shape().lens();
            if(lens.empty())
            {
                continue;
            }

            auto type = arg->get_shape().type();
            // output size
            if(type == shape::int64_type)
            {
                auto arg_out_s = arg->eval();
                check_arg_empty(arg_out_s,
                                "PARSE_" + opd.op_name + ": dynamic output size is not supported!");
                arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });

                if(out_lens.size() != in_lens.size())
                {
                    MIGRAPHX_THROW("PARSE_" + opd.op_name +
                                   ": specified output size does not match input size");
                }

                // compute the scale
                vec_scale.resize(in_lens.size());
                std::transform(in_lens.begin(),
                               in_lens.end(),
                               out_lens.begin(),
                               vec_scale.begin(),
                               [](auto iss, auto oss) { return 1.0 * oss / iss; });
            }
            else
            {
                // scale input
                if(lens[0] == in_lens.size())
                {
                    auto arg_scale = arg->eval();
                    check_arg_empty(arg_scale,
                                    "PARSE_" + opd.op_name +
                                        ": dynamic input scale is not supported!");

                    arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
                    if(in_lens.size() != vec_scale.size())
                    {
                        MIGRAPHX_THROW("PARSE_" + opd.op_name +
                                       ": ranks of input and scale are different!");
                    }

                    // derive the output size from the scales
                    std::transform(in_lens.begin(),
                                   in_lens.end(),
                                   vec_scale.begin(),
                                   out_lens.begin(),
                                   [&](auto idx, auto scale) {
                                       return static_cast<std::size_t>(idx * scale);
                                   });
                }
            }
        }

        shape out_s{in_s.type(), out_lens};
        std::size_t out_elements = out_s.elements();
        // maps an output coordinate back to its (fractional) input coordinate
        auto idx_op = get_original_idx_op(coord_trans_mode);

        // reshape input to one-dimension
        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
        args[0]  = info.make_contiguous(args[0]);
        auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);

        if(mode == "nearest")
        {
            // one flattened-input index per output element
            // NOTE(review): the index is computed as int64 but stored in an
            // int vector / int32 literal; assumes the flattened input size
            // fits in int32 — confirm
            std::vector<int> ind(out_elements);

            // map out_idx to in_idx
            auto nearest_op = get_nearest_op(nearest_mode);

            shape_for_each(out_s, [&](auto idx) {
                auto in_idx = idx;
                for(auto ii = 0; ii < in_lens.size(); ++ii)
                {
                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]);
                    in_idx[ii]   = nearest_op(in_lens[ii], idx_val);
                }
                ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
            });

            shape ind_s{shape::int32_type, out_lens};
            auto ins_ind = info.add_literal(literal(ind_s, ind));
            return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
        }
        // linear mode
        else
        {
            auto nearest_floor = get_nearest_op("floor");
            auto nearest_ceil  = get_nearest_op("ceil");

            // get the number of dimensions
            std::size_t n_dim = out_lens.size();
            // vvv_ind[dim][0/1][out_element]: floor/ceil input index of each
            // output element along each dimension
            std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements));
            std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
            // delta[dim][out_element]: fractional distance above the floor index
            std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));

            shape_for_each(out_s, [&](auto idx) {
                auto in_idx  = idx; // NOTE(review): unused in this branch
                auto out_idx = out_s.index(idx);
                for(auto ii = 0; ii < in_lens.size(); ++ii)
                {
                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]);
                    vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
                    vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
                    delta[ii][out_idx]      = idx_val - vvv_ind[ii][0][out_idx];
                }
            });

            // enumerate all 2^n_dim neighbor corners of every output element
            std::vector<std::vector<std::size_t>> vec_dims(out_elements);
            auto ind      = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s);
            auto ind_lens = out_lens;
            // stack the 2^n_dim neighbor index sets along the leading axis
            ind_lens[0] *= (std::size_t{1} << n_dim);
            shape ind_s{shape::int32_type, ind_lens};
            auto ins_ind = info.add_literal(literal(ind_s, ind));
            auto data    = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);

            // Interpolate one dimension per pass (innermost first); each pass
            // lerps the two halves of the leading axis and halves its length.
            auto dim_lens = out_lens;
            dim_lens[0] *= (std::size_t{1} << (n_dim - 1));
            for(std::size_t i = 0; i < n_dim; ++i)
            {
                shape dim_s{shape::float_type, dim_lens};
                const auto& dim_delta = delta[n_dim - i - 1];
                // replicate the per-element delta across the remaining slices
                std::vector<float> delta_data;
                for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j)
                {
                    delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end());
                }
                auto ins_delta = info.add_literal(dim_s, delta_data);

                // slice the data into the "floor" and "ceil" halves
                int64_t slc_stride = dim_lens[0];
                auto low           = info.add_instruction(
                    make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {slc_stride}}}),
                    data);
                auto hi = info.add_instruction(
                    make_op("slice",
                            {{"axes", {0}}, {"starts", {slc_stride}}, {"ends", {2 * slc_stride}}}),
                    data);
                // low + delta * (hi - low)
                auto diff = info.add_instruction(make_op("sub"), hi, low);
                auto ddf  = info.add_instruction(make_op("mul"), diff, ins_delta);
                data      = info.add_instruction(make_op("add"), ddf, low);
                dim_lens[0] /= 2;
            }
            return data;
        }
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
rest of the sequence in the original order. Variable sequence_lens is not supported in this
version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The
batch axis and time axis must be [0, 1] and not the same.
*/
struct parse_reversesequence : op_parser<parse_reversesequence>
{
    std::vector<op_desc> operators() const { return {{"ReverseSequence"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Read an axis attribute, falling back to a default when absent.
        auto read_axis = [&](const std::string& name, int fallback) {
            return contains(info.attributes, name)
                       ? static_cast<int>(info.attributes.at(name).i())
                       : fallback;
        };

        const int batch_axis = read_axis("batch_axis", 1);
        if(batch_axis != 0 and batch_axis != 1)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1");
        }
        const int time_axis = read_axis("time_axis", 0);
        if(time_axis != 0 and time_axis != 1)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1");
        }
        if(time_axis == batch_axis)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same");
        }

        auto data      = args[0];
        auto data_lens = data->get_shape().lens();
        if(data_lens.size() < 2)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2");
        }

        // sequence_lens: either the second (constant) input or an attribute
        std::vector<int64_t> seq_lens;
        if(args.size() == 2)
        {
            migraphx::argument lens_arg = args.back()->eval();
            check_arg_empty(lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens");
            lens_arg.visit([&](auto v) { seq_lens.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "sequence_lens"))
        {
            literal lit = parser.parse_value(info.attributes.at("sequence_lens"));
            lit.visit([&](auto v) { seq_lens.assign(v.begin(), v.end()); });
        }

        const auto batch_size = data_lens[batch_axis];
        const auto time_size  = data_lens[time_axis];
        // this condition may still work if sequence_len's shape was incorrect
        if(seq_lens.size() != batch_size)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape");
        }

        // Slice out a single batch entry restricted to [t_begin, t_end)
        // along the time axis.
        auto slice_batch = [&](int b, int t_begin, int t_end) {
            return info.add_instruction(make_op("slice",
                                                {{"axes", {batch_axis, time_axis}},
                                                 {"starts", {b, t_begin}},
                                                 {"ends", {b + 1, t_end}}}),
                                        data);
        };

        // Reverse each batch's leading seq_lens[b] time steps, keep the rest
        // in order, then concatenate the batches back together.
        instruction_ref result;
        for(int b = 0; b < batch_size; ++b)
        {
            instruction_ref piece;
            if(seq_lens[b] > 1)
            {
                piece = slice_batch(b, 0, seq_lens[b]);
                piece = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), piece);
                // if reversed less than whole batch, concat rest of batch
                if(seq_lens[b] < time_size)
                {
                    auto tail = slice_batch(b, seq_lens[b], time_size);
                    piece =
                        info.add_instruction(make_op("concat", {{"axis", time_axis}}), piece, tail);
                }
            }
            else
            { // length 0 or 1: reversing changes nothing
                piece = slice_batch(b, 0, time_size);
            }
            result = (b == 0) ? piece
                              : info.add_instruction(
                                    make_op("concat", {{"axis", batch_axis}}), result, piece);
        }
        return result;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/map_activation_functions.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parser for the ONNX RNN operator (vanilla single-gate recurrent layer).
// Fix: removed the unused local `input_shape` that was computed and never read.
struct parse_rnn : op_parser<parse_rnn>
{
    std::vector<op_desc> operators() const { return {{"RNN"}}; }

    // Returns two outputs: the concatenation of all hidden states and the
    // last hidden state.
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       onnx_parser::node_info info,
                                       std::vector<instruction_ref> args) const
    {
        // hidden size comes from the weight tensor (args[1]) second dimension
        std::size_t hidden_size = args[1]->get_shape().lens()[1];

        // the attribute, when present, must agree with the weight shape
        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att =
                parser.parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // direction: forward (default), reverse, or bidirectional
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        // activation function names, lower-cased; default is tanh
        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            std::transform(names.begin(),
                           names.end(),
                           std::back_inserter(vec_names),
                           [](auto name) { return to_lower(name); });
        }

        // reject any activation we have no mapping for
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_activation_functions().count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // bidirectional case should have two activation functions.
        // one is for forward, and the other is for reverse.
        // if only one actv function is provided, we use it in both
        // forward and reverse direction
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        // resolve names to the actual activation operations
        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_activation_functions().at(fn); });

        // clip threshold, forwarded to the op (clipping to be added later)
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parser.parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operator to have 6 arguments
        if(args.size() < 6)
        {
            auto ins = info.add_instruction(make_op("undefined"));
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = info.add_instruction(make_op("rnn",
                                                          {{"hidden_size", hidden_size},
                                                           {"actv_func", to_value(vec_actv_funcs)},
                                                           {"direction", dirct},
                                                           {"clip", clip}}),
                                                  args);

        // second output for the last hidden state
        auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states);

        return {hidden_states, last_output};
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/op/common.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_roialign : op_parser<parse_roialign>
{
    std::vector<op_desc> operators() const { return {{"RoiAlign"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        const auto& attrs = info.attributes;

        // coordinate transformation mode: only the two half-pixel variants
        // are accepted
        std::string coord_trans_mode = "half_pixel";
        if(contains(attrs, "coordinate_transformation_mode"))
        {
            coord_trans_mode = attrs.at("coordinate_transformation_mode").s();
        }
        if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode))
        {
            MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode +
                           "\": invalid value!");
        }

        // pooling mode; default is "avg", anything other than "max" keeps it
        migraphx::op::pooling_mode rmode(migraphx::op::pooling_mode::average);
        if(contains(attrs, "mode") and attrs.at("mode").s() == "max")
        {
            rmode = migraphx::op::pooling_mode::max;
        }

        // integer attribute with a fallback default
        auto int_attr = [&](const std::string& name, int64_t fallback) {
            return contains(attrs, name) ? attrs.at(name).i() : fallback;
        };
        const int64_t output_height  = int_attr("output_height", 1);
        const int64_t output_width   = int_attr("output_width", 1);
        const int64_t sampling_ratio = int_attr("sampling_ratio", 0);

        float spatial_scale = 1.0f;
        if(contains(attrs, "spatial_scale"))
        {
            spatial_scale = attrs.at("spatial_scale").f();
        }

        return info.add_instruction(make_op("roialign",
                                            {{"coordinate_transformation_mode", coord_trans_mode},
                                             {"mode", rmode},
                                             {"output_height", output_height},
                                             {"output_width", output_width},
                                             {"sampling_ratio", sampling_ratio},
                                             {"spatial_scale", spatial_scale}}),
                                    args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_scatter : op_parser<parse_scatter>
{
    std::vector<op_desc> operators() const { return {{"ScatterElements"}, {"Scatter"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();

        // The reduction attribute selects which scatter operator we emit:
        // scatter_none (default), scatter_add, or scatter_mul. Future
        // reduction modes should follow the same naming pattern and be added
        // to the validity check below.
        std::string reduction = "none";
        if(contains(info.attributes, "reduction"))
        {
            reduction = std::string(info.attributes.at("reduction").s());
            if(not contains({"none", "add", "mul"}, reduction))
                MIGRAPHX_THROW("PARSE_SCATTER: unsupported reduction mode " + reduction);
        }

        return info.add_instruction(
            migraphx::make_op(std::string("scatter_") + reduction, {{"axis", axis}}), args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_scatternd : op_parser<parse_scatternd>
{
    std::vector<op_desc> operators() const { return {{"ScatterND"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref>& args) const
    {
        // Pick the scatternd variant from the optional reduction attribute;
        // anything other than "add"/"mul" falls back to plain assignment.
        std::string op_name = "scatternd_none";
        if(contains(info.attributes, "reduction"))
        {
            const auto& reduction = info.attributes.at("reduction").s();
            if(reduction == "add")
                op_name = "scatternd_add";
            else if(reduction == "mul")
                op_name = "scatternd_mul";
        }
        return info.add_instruction(migraphx::make_op(op_name), args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_selu : op_parser<parse_selu>
{
    std::vector<op_desc> operators() const { return {{"Selu"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        auto x_shape   = args[0]->get_shape();
        auto elem_type = x_shape.type();
        auto x_lens    = x_shape.lens();

        // ONNX defaults for the Selu coefficients
        float alpha = 1.67326f;
        if(contains(info.attributes, "alpha"))
        {
            alpha = info.attributes.at("alpha").f();
        }
        float gamma = 1.0507f;
        if(contains(info.attributes, "gamma"))
        {
            gamma = info.attributes.at("gamma").f();
        }

        auto lit_alpha      = info.add_literal({{elem_type, {1}}, {alpha}});
        auto lit_half_gamma = info.add_literal({{elem_type, {1}}, {gamma / 2.0f}});
        if(x_lens != std::vector<std::size_t>{1})
        {
            lit_alpha =
                info.add_instruction(make_op("multibroadcast", {{"out_lens", x_lens}}), lit_alpha);
            lit_half_gamma = info.add_instruction(make_op("multibroadcast", {{"out_lens", x_lens}}),
                                                  lit_half_gamma);
        }

        // Branch-free formulation:
        //   selu(x) = gamma/2 * ((a*e^x - a + x) - sign(x) * (a*e^x - a - x))
        // which reduces to gamma*x for x > 0 and gamma*a*(e^x - 1) for x < 0.
        auto sgn         = info.add_instruction(make_op("sign"), args[0]);
        auto exp_term    = info.add_instruction(make_op("exp"), args[0]);
        auto alpha_exp   = info.add_instruction(make_op("mul"), lit_alpha, exp_term);
        auto shifted     = info.add_instruction(make_op("sub"), alpha_exp, lit_alpha);
        auto plus_x      = info.add_instruction(make_op("add"), shifted, args[0]);
        auto minus_x     = info.add_instruction(make_op("sub"), shifted, args[0]);
        auto signed_term = info.add_instruction(make_op("mul"), sgn, minus_x);
        auto combined    = info.add_instruction(make_op("sub"), plus_x, signed_term);
        return info.add_instruction(make_op("mul"), combined, lit_half_gamma);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
struct parse_shape : op_parser<parse_shape>
{
    std::vector<op_desc> operators() const { return {{"Shape"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");

        // emit the input's dimensions as a 1-D int64 constant
        auto lens = args[0]->get_shape().lens();
        std::vector<int64_t> dims(lens.begin(), lens.end());
        migraphx::shape out_s(migraphx::shape::int64_type, {dims.size()});
        return info.add_literal(migraphx::literal{out_s, dims});
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_size : op_parser<parse_size>
{
    std::vector<op_desc> operators() const { return {{"Size"}}; }

    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Size is just the total element count of the input, emitted as a
        // scalar int64 literal.
        const auto n_elements = args[0]->get_shape().elements();
        migraphx::shape scalar_s{migraphx::shape::int64_type};
        return info.add_literal(migraphx::literal{scalar_s, {n_elements}});
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parser for the ONNX Slice operator.
// Bounds may come either as attributes (starts/ends/axes, opset-1 style) or
// as constant inputs (data, starts, ends, axes, steps). Negative steps are
// implemented by flipping the slice window and appending a "reverse";
// |step| > 1 is handled by a trailing "step" op.
struct parse_slice : op_parser<parse_slice>
{
    std::vector<op_desc> operators() const { return {{"Slice"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        op::slice op;
        std::vector<int64_t> steps;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
        }

        // axes: 4th input takes precedence over the attribute
        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parser.parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        // ends: 3rd input or attribute
        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parser.parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        // starts: 2nd input or attribute
        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parser.parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        // default axes = [0, rank) when not supplied
        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        // Rewrite axes with negative steps into a forward slice followed by
        // a reverse: the window [start, end) walked backwards equals the
        // window (end, start] walked forwards, so both bounds are shifted by
        // one and swapped.
        std::vector<int64_t> raxes;
        assert(steps.empty() or steps.size() == op.axes.size());
        assert(op.axes.size() == op.starts.size());
        assert(op.axes.size() == op.ends.size());
        for(auto i : range(steps.size()))
        {
            if(steps[i] >= 0)
                continue;
            op.starts[i] += 1;
            // a start of -1 (last element) becomes 0 after the shift, which
            // would read as an empty bound; INT_MAX clamps to the dim end
            if(op.starts[i] == 0)
                op.starts[i] = INT_MAX;
            op.ends[i] += 1;
            raxes.push_back(op.axes[i]);
            std::swap(op.starts[i], op.ends[i]);
        }

        auto ins = info.add_instruction(op, args[0]);
        if(not raxes.empty())
            ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);

        // any |step| != 1: apply a "step" op with the step magnitudes
        if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
        {
            std::vector<int64_t> nsteps;
            std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) {
                return std::abs(s);
            });

            return ins = info.add_instruction(
                       make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins);
        }
        else
            return ins;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment