Unverified Commit 1dd4e4d9 authored by Paul Fultz II, committed by GitHub

Refactor onnx parser (#699)



* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold

* Refactor onnx_parser class

* Formatting

* Add op_parser

* Formatting

* Remove old onnx drivers

* Use file GLOB

* Parse arg ops

* Formatting

* Add pooling

* Formatting

* Add parse_batchnorm

* Add more operators

* Formatting

* Add more operators

* Formatting

* Add more operators

* Formatting

* Add more operators

* Add rnn operators

* Formatting

* Fix tidy issues

* Formatting

* Add back missing param

* Formatting

* Fix shadow variable

* Fix shadow in declaration

* Make global constant

* Formatting

* Add generic op

* Formatting

* Add binary op

* Formatting

* Add variadic op

* Formatting

* Remove unused fields and functions

* Set default values

* Formatting

* Remove unused member variable

* Add add_literal overload

* Use info.add_literal

* Formatting

* Call add_instruction through info class

* Fix tidy issues

* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 69d2e38f
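The refactor splits the old monolithic onnx_parser into self-registering op_parser subclasses: each one lists the ONNX operator names it handles in operators() and builds MIGraphX instructions in parse() through make_op and the node_info helpers. A minimal sketch of the pattern follows; the operator name and body are hypothetical and only illustrate the registration shape used by the parsers below (in the actual tree, simple unary cases like this are covered by the generic op parsers the commit adds).

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

// Hypothetical example parser: deriving from op_parser<T> registers the class,
// operators() names the ONNX ops it handles, and parse() emits instructions
// through the node_info helpers using make_op.
struct parse_neg_example : op_parser<parse_neg_example>
{
    std::vector<op_desc> operators() const { return {{"Neg"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        return info.add_instruction(make_op("neg"), args.front());
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx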
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
template <class T>
std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
{
std::vector<std::size_t> indices;
for(std::size_t i = 0; i < data.size(); ++i)
{
if(!float_equal(data[i], 0))
indices.push_back(i);
}
return indices;
}
struct parse_nonzero : op_parser<parse_nonzero>
{
std::vector<op_desc> operators() const { return {{"NonZero"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
migraphx::argument data_arg = args.back()->eval();
check_arg_empty(data_arg, "PARSE_NONZERO: cannot support non-constant input!");
std::vector<std::size_t> indices;
data_arg.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
std::vector<val_type> vec_data;
vec_data.assign(val.begin(), val.end());
indices = nonzero_indices(vec_data);
});
shape in_s = args[0]->get_shape();
shape out_s{shape::int64_type, {in_s.lens().size(), indices.size()}};
std::vector<int64_t> out_data(out_s.elements());
for(std::size_t i = 0; i < indices.size(); ++i)
{
auto idx = in_s.multi(indices[i]);
for(std::size_t j = 0; j < in_s.lens().size(); ++j)
{
out_data[out_s.index({j, i})] = idx[j];
}
}
return info.add_literal(literal(out_s, out_data));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
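As a concrete check of the layout built above (a standalone sketch, independent of migraphx): for a 2x2 input {1, 0, 0, 3}, the nonzero flat indices are {0, 3}, which unravel to coordinates (0, 0) and (1, 1), so the int64 output has shape {2, 2} with one row per input dimension.

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // 2x2 row-major input: {{1, 0}, {0, 3}}
    std::vector<float> data = {1, 0, 0, 3};
    std::vector<std::size_t> lens = {2, 2};
    std::vector<std::size_t> flat; // flat indices of the nonzero elements
    for(std::size_t i = 0; i < data.size(); ++i)
        if(data[i] != 0)
            flat.push_back(i);
    // The output is a rank x count matrix: column i holds the multi-index
    // of the i-th nonzero element, here (0,0) and (1,1).
    for(std::size_t f : flat)
        std::cout << f / lens[1] << ' '; // row coordinates: 0 1
    std::cout << '\n';
    for(std::size_t f : flat)
        std::cout << f % lens[1] << ' '; // column coordinates: 0 1
    std::cout << '\n';
}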
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_onehot : op_parser<parse_onehot>
{
std::vector<op_desc> operators() const { return {{"OneHot"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
migraphx::argument depth_arg = args[1]->eval();
check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");
size_t depth = depth_arg.at<size_t>();
int64_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = info.attributes.at("axis").i();
}
std::vector<float> depth_input(depth * depth, 0.0f);
for(int i = 0; i < depth; i++)
{
depth_input[depth * i + i] = 1.0f;
}
auto type = args[2]->get_shape().type();
shape s{type, {depth, depth}};
auto l_val = info.mm->add_literal({s, depth_input});
auto gather_out = info.add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]});
// Finally, we need a transpose to move the inner most dim to the axis dim
int n_rank = gather_out->get_shape().lens().size();
if(axis < -n_rank or axis >= n_rank)
{
MIGRAPHX_THROW("PARSE_ONEHOT: axis out of range");
}
int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = info.add_instruction(make_op("transpose", {{"dims", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = info.add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto on_val = info.add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val);
auto unsq_diff_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff);
auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
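A numeric sketch of the OneHot decomposition above (standalone, with illustrative values): with depth = 3, indices = {2, 0}, off_value = -1 and on_value = 5, gathering rows of the 3x3 identity gives the one-hot rows {0, 0, 1} and {1, 0, 0}, and the final mul/add rescales each entry as off + onehot * (on - off).

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::size_t depth = 3;
    // identity table that the parser adds as a literal
    std::vector<float> identity(depth * depth, 0.0f);
    for(std::size_t i = 0; i < depth; ++i)
        identity[depth * i + i] = 1.0f;
    std::vector<std::size_t> indices = {2, 0};
    float off = -1.0f, on = 5.0f;
    for(std::size_t idx : indices)
    {
        for(std::size_t j = 0; j < depth; ++j)
        {
            float onehot = identity[depth * idx + j]; // gather along axis 0
            std::cout << off + onehot * (on - off) << ' ';
        }
        std::cout << '\n'; // prints: -1 -1 5   then   5 -1 -1
    }
}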
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
void calc_reflect_indices(std::vector<int>& indices, const int64_t num_dims)
{
int k = 0;
bool reversed = false;
// in reflect padding, if num_pads > num_dims, compute the extra
// pad indices periodically, e.g. (1, 2, 3, 2, 1, 0)
for(int& idx : indices)
{
if(k == num_dims - 1)
reversed = true;
if(k == 0)
reversed = false;
if(reversed)
k--;
else
k++;
idx = k;
}
}
instruction_ref reflect_pad(const onnx_parser::node_info& info,
const std::vector<int64_t>& pads,
instruction_ref input)
{
size_t num_dims = pads.size() / 2;
std::vector<int> ldims(pads.begin(), pads.begin() + num_dims);
std::vector<int> rdims(pads.begin() + num_dims, pads.end());
assert(ldims.size() == rdims.size());
std::vector<int64_t> axes(num_dims);
std::iota(axes.begin(), axes.end(), int64_t{0});
// iterate over dimensions, starting from lowest dimension
for(int64_t i = num_dims - 1; i >= 0; i--)
{
auto axis = i;
auto lcount = ldims.at(i);
auto rcount = rdims.at(i);
if(lcount == 0 and rcount == 0) // no padding for current dim
continue;
// calculate starts and ends for each iteration since shape may change
std::vector<size_t> dims = input->get_shape().lens();
std::vector<int64_t> starts(axes.size(), 0);
std::vector<int64_t> ends(dims.begin(), dims.end());
std::vector<instruction_ref> slices;
auto starts_it = starts.begin() + i;
auto ends_it = ends.begin() + i;
auto dims_it = dims.begin() + i;
std::vector<int> l_indices(lcount);
std::vector<int> r_indices(rcount);
// compute slice indices in a periodic fashion
calc_reflect_indices(l_indices, *dims_it);
calc_reflect_indices(r_indices, *dims_it);
for(int idx : l_indices)
{
*starts_it = idx;
*ends_it = *starts_it + 1;
slices.push_back(info.mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
}
// when padding on the left side, the outermost pad should be at the beginning
std::reverse(slices.begin(), slices.end());
slices.push_back(input);
for(int idx : r_indices)
{
*starts_it = *dims_it - idx - 1;
*ends_it = *starts_it + 1;
slices.push_back(info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
}
input = info.mm->add_instruction(make_op("concat", {{"axis", axis}}), slices);
}
return input;
}
struct parse_pad : op_parser<parse_pad>
{
std::vector<op_desc> operators() const { return {{"Pad"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::vector<int64_t> pads{};
if(args.size() >= 2)
{
auto pad_arg = args.at(1)->eval();
check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
}
else if(contains(info.attributes, "pads"))
{
auto&& pad_vals = info.attributes["pads"].ints();
pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
}
else
{
MIGRAPHX_THROW("PARSE_PAD: pad must be available");
}
// check if padding is actually being done (at least one value is nonzero)
if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
{
return info.add_instruction(make_op("identity"), args.front());
}
if(contains(info.attributes, "mode"))
{
auto mode = info.attributes.at("mode").s();
if(mode == "reflect")
return reflect_pad(info, pads, args.front());
if(mode != "constant")
{
MIGRAPHX_THROW(
"PARSE_PAD: migraphx currently only supports constant and reflect padding");
}
}
float value = 0.0f;
// third input is the value
if(args.size() == 3)
{
auto val_ins = args.at(2);
if(!val_ins->can_eval())
{
MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
}
auto val_arg = val_ins->eval();
if(val_arg.get_shape().elements() != 1)
{
MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
}
value = val_arg.at<float>();
}
else if(contains(info.attributes, "value"))
{
value = parser.parse_value(info.attributes.at("value")).at<float>();
}
return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}),
args.front());
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
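To see the periodic index pattern that calc_reflect_indices produces, here is a standalone reproduction of the same logic: for a dimension of size 4 and six requested pads, the slice offsets bounce between the edges as 1, 2, 3, 2, 1, 0, matching the comment above.

#include <cstdint>
#include <iostream>
#include <vector>

// Same logic as calc_reflect_indices above, reproduced standalone.
void calc_reflect_indices(std::vector<int>& indices, const int64_t num_dims)
{
    int k = 0;
    bool reversed = false;
    for(int& idx : indices)
    {
        if(k == num_dims - 1)
            reversed = true;
        if(k == 0)
            reversed = false;
        if(reversed)
            k--;
        else
            k++;
        idx = k;
    }
}

int main()
{
    std::vector<int> idx(6);
    calc_reflect_indices(idx, 4); // dimension of size 4, six pads requested
    for(int i : idx)
        std::cout << i << ' '; // prints: 1 2 3 2 1 0
    std::cout << '\n';
}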
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_pooling : op_parser<parse_pooling>
{
std::vector<op_desc> operators() const
{
return {{"AveragePool", "average"},
{"GlobalAveragePool", "average"},
{"GlobalMaxPool", "max"},
{"MaxPool", "max"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::string mode = opd.op_name;
operation op = make_op("pooling", {{"mode", mode}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
if(starts_with(opd.onnx_name, "Global"))
{
values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
}
// forward the ceil_mode attribute to the pooling operator if present
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
// count_include_pad: if it is set to 1, we always use
// explicit padding
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
count_include_pad = info.attributes.at("count_include_pad").i();
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(contains(info.attributes, "auto_pad"))
{
values["padding"].clear();
// the returned paddings may be empty; they are then set to 0 (no padding)
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
{1, 1},
in_lens,
paddings);
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding(paddings.begin(), paddings.end());
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(!slice_start.empty())
{
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
in_lens = l0->get_shape().lens();
for(size_t i = 0; i < kdims; i++)
{
if(values["lengths"][i].to<int64_t>() >
in_lens[i + 2] + 2 * values["padding"][i].to<int64_t>())
{
MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
}
}
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(!slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
l1);
}
return l1;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
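For the Global variants above, the kernel lengths are taken directly from the spatial dimensions of the input, so a GlobalAveragePool over an NCHW tensor of shape {1, 3, 5, 7} becomes an average pooling with lengths {5, 7}. A standalone sketch of that step (shape values are illustrative):

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::size_t> in_lens = {1, 3, 5, 7}; // NCHW input shape
    // Global pooling: kernel lengths are the trailing spatial dims
    std::vector<std::size_t> lengths(in_lens.begin() + 2, in_lens.end());
    for(auto l : lengths)
        std::cout << l << ' '; // prints: 5 7
    std::cout << '\n';
}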
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_range : op_parser<parse_range>
{
std::vector<op_desc> operators() const { return {{"Range"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto start_arg = args[0]->eval();
check_arg_empty(start_arg, "PARSE_RANGE: start arg dynamic shape is not supported");
auto limit_arg = args[1]->eval();
check_arg_empty(limit_arg, "PARSE_RANGE: limit arg dynamic shape is not supported");
auto delta_arg = args[2]->eval();
check_arg_empty(delta_arg, "PARSE_RANGE: delta arg dynamic shape is not supported");
assert(args[0]->get_shape().elements() == 1 and args[1]->get_shape().elements() == 1 and
args[2]->get_shape().elements() == 1);
instruction_ref l0;
visit_all(start_arg, limit_arg, delta_arg)([&](auto start, auto limit, auto delta) {
auto start_val = start.front();
auto limit_val = limit.front();
auto delta_val = delta.front();
size_t num_elements = static_cast<size_t>(
ceil(static_cast<double>(limit_val - start_val) / static_cast<double>(delta_val)));
assert(num_elements > 0);
using type = decltype(start_val);
std::vector<type> range_vals(num_elements);
std::generate(range_vals.begin(), range_vals.end(), [&]() {
auto result = start_val;
start_val += delta_val;
return result;
});
l0 = info.add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals});
});
return l0;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
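A quick check of the element count and values computed above (standalone, with illustrative inputs): Range(start = 1, limit = 10, delta = 3) gives ceil((10 - 1) / 3) = 3 elements, {1, 4, 7}.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    int start = 1, limit = 10, delta = 3;
    auto n = static_cast<std::size_t>(
        std::ceil(static_cast<double>(limit - start) / static_cast<double>(delta)));
    std::vector<int> range_vals(n);
    std::generate(range_vals.begin(), range_vals.end(), [&]() {
        auto result = start;
        start += delta;
        return result;
    });
    for(int v : range_vals)
        std::cout << v << ' '; // prints: 1 4 7
    std::cout << '\n';
}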
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
instruction_ref parse_reduce_oper(const std::string& op_name,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(info.attributes, "axes"))
{
axes.clear();
auto&& attr_axes = info.attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(info.attributes, "keepdims"))
{
keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return info.add_instruction(make_op(op_name, {{"axes", axes}}), args);
}
else
{
auto ins = info.add_instruction(make_op(op_name, {{"axes", axes}}), args);
return info.add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
}
}
struct parse_reduce_op : op_parser<parse_reduce_op>
{
std::vector<op_desc> operators() const
{
return {{"ReduceMax", "reduce_max"},
{"ReduceMean", "reduce_mean"},
{"ReduceMin", "reduce_min"},
{"ReduceProd", "reduce_prod"},
{"ReduceSum", "reduce_sum"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
return parse_reduce_oper(opd.op_name, parser, std::move(info), std::move(args));
}
};
struct parse_reduce_l1 : op_parser<parse_reduce_l1>
{
std::vector<op_desc> operators() const { return {{"ReduceL1"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto abs_ins = info.add_instruction(make_op("abs"), args[0]);
return parse_reduce_oper("reduce_sum", parser, std::move(info), {abs_ins});
}
};
struct parse_reduce_l2 : op_parser<parse_reduce_l2>
{
std::vector<op_desc> operators() const { return {{"ReduceL2"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto square_ins = info.add_instruction(make_op("mul"), args[0], args[0]);
auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, {square_ins});
return info.add_instruction(make_op("sqrt"), sum_ins);
}
};
struct parse_reduce_log_sum : op_parser<parse_reduce_log_sum>
{
std::vector<op_desc> operators() const { return {{"ReduceLogSum"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, std::move(args));
return info.add_instruction(make_op("log"), sum_ins);
}
};
struct parse_reduce_log_sum_exp : op_parser<parse_reduce_log_sum_exp>
{
std::vector<op_desc> operators() const { return {{"ReduceLogSumExp"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto exp_ins = info.add_instruction(make_op("exp"), args[0]);
auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, {exp_ins});
return info.add_instruction(make_op("log"), sum_ins);
}
};
struct parse_reduce_sum_square : op_parser<parse_reduce_sum_square>
{
std::vector<op_desc> operators() const { return {{"ReduceSumSquare"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto square_ins = info.add_instruction(make_op("mul"), args[0], args[0]);
return parse_reduce_oper("reduce_sum", parser, std::move(info), {square_ins});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_reshape : op_parser<parse_reshape>
{
std::vector<op_desc> operators() const { return {{"Reshape"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::vector<int64_t> dims;
if(args.size() == 1)
{
literal s = parser.parse_value(info.attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
}
if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "Reshape: dynamic shape is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
}
return info.add_instruction(make_op("reshape", {{"dims", dims}}),
info.make_contiguous(args[0]));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
const auto& get_nearest_op(const std::string& mode)
{
using nearest_op = std::function<std::size_t(std::size_t, double)>;
static std::unordered_map<std::string, nearest_op> const nearest_ops = {
{"round_prefer_floor",
[=](std::size_t d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::ceil((val - 0.5)));
}},
{"round_prefer_ceil",
[=](std::size_t d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::round((val)));
}},
{"floor",
[=](std::size_t d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::floor((val)));
}},
{"ceil", [=](std::size_t d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::ceil((val)));
}}};
if(!contains(nearest_ops, mode))
{
MIGRAPHX_THROW("PARSE_RESIZE: nearest_mode " + mode + " not supported!");
}
return nearest_ops.at(mode);
}
const auto& get_original_idx_op(const std::string& mode)
{
using original_idx_op = std::function<double(std::size_t, std::size_t, std::size_t, double)>;
static std::unordered_map<std::string, original_idx_op> const idx_ops = {
{"half_pixel",
[=](std::size_t, std::size_t, std::size_t idx, double scale) {
return (idx + 0.5) / scale - 0.5;
}},
{"pytorch_half_pixel",
[=](std::size_t, std::size_t l_out, std::size_t idx, double scale) {
return l_out > 1 ? (idx + 0.5) / scale - 0.5 : 0.0;
}},
{"align_corners",
[=](std::size_t l_in, std::size_t l_out, std::size_t idx, double) {
return 1.0 * idx * (l_in - 1.0) / (l_out - 1.0);
}},
{"asymmetric",
[=](std::size_t, std::size_t, std::size_t idx, double scale) { return idx / scale; }},
{"tf_half_pixel_for_nn", [=](std::size_t, std::size_t, std::size_t idx, double scale) {
return (idx + 0.5) / scale;
}}};
if(!contains(idx_ops, mode))
{
MIGRAPHX_THROW("PARSE_RESIZE: coordinate_transformation_mode " + mode + " not supported!");
}
return idx_ops.at(mode);
}
struct parse_resize : op_parser<parse_resize>
{
std::vector<op_desc> operators() const { return {{"Resize"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::string coord_trans_mode = "half_pixel";
if(contains(info.attributes, "coordinate_transformation_mode"))
{
coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s();
// does not support transformation mode "tf_crop_and_resize"
if(coord_trans_mode == "tf_crop_and_resize")
{
MIGRAPHX_THROW("PARSE_RESIZE: \"tf_crop_and_resize\" mode is not supported!");
}
}
// mode: only nearest mode is supported for now
if(contains(info.attributes, "mode"))
{
auto mode = info.attributes.at("mode").s();
if(mode != "nearest")
{
MIGRAPHX_THROW("PARSE_RESIZE: only nearest mode is supported!");
}
}
// nearest mode
std::string nearest_mode = "round_prefer_floor";
if(contains(info.attributes, "nearest_mode"))
{
nearest_mode = info.attributes.at("nearest_mode").s();
}
// check exclude_outside; only the default value 0 is supported
if(contains(info.attributes, "exclude_outside"))
{
int exclude_outside = info.attributes.at("exclude_outside").i();
if(exclude_outside == 1)
{
MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
}
}
// input data shape info
auto in_s = args[0]->get_shape();
auto in_lens = in_s.lens();
// output shape is explicitly specified
std::vector<std::size_t> out_lens(in_lens.size());
// scale
std::vector<double> vec_scale;
// output size is specified in input, so use it as output size
if(args.size() == 4 and args.back()->name() != "undefined")
{
auto arg_out_s = args[3]->eval();
check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
}
// compute the scale
vec_scale.resize(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return 1.0 * oss / iss; });
}
// need to compute the output lens from input
else
{
auto arg_scale = args[2]->eval();
check_arg_empty(arg_scale, "PARSE_RESIZE: dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
}
std::transform(
in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
}
shape out_s{in_s.type(), out_lens};
std::vector<int> ind(out_s.elements());
// map out_idx to in_idx
auto nearest_op = get_nearest_op(nearest_mode);
auto idx_op = get_original_idx_op(coord_trans_mode);
shape_for_each(out_s, [&](auto idx) {
auto in_idx = idx;
for(auto ii = 0; ii < in_lens.size(); ++ii)
{
auto idx_val = idx_op(in_lens[ii], out_lens[ii], in_idx[ii], vec_scale[ii]);
in_idx[ii] = nearest_op(in_lens[ii], idx_val);
}
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
});
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
shape ind_s{shape::int32_type, out_lens};
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = info.add_literal(literal(ind_s, ind));
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
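As a numeric check of the index mapping above (standalone, using the parser defaults: half_pixel coordinate transform and round_prefer_floor nearest mode): upscaling a length-2 axis to length 4 with scale 2 maps output indices 0..3 back to input indices 0, 0, 1, 1.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>

int main()
{
    std::size_t d_in = 2;  // input length along one axis
    std::size_t d_out = 4; // output length
    double scale = 2.0;
    for(std::size_t idx = 0; idx < d_out; ++idx)
    {
        // half_pixel coordinate transform, then round_prefer_floor rounding
        double val = (idx + 0.5) / scale - 0.5;
        val = std::max(0.0, std::min(d_in - 1.0, val));
        auto in_idx = static_cast<std::size_t>(std::ceil(val - 0.5));
        std::cout << in_idx << ' '; // prints: 0 0 1 1
    }
    std::cout << '\n';
}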
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/map_activation_functions.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_rnn : op_parser<parse_rnn>
{
std::vector<op_desc> operators() const { return {{"RNN"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
if(contains(info.attributes, "hidden_size"))
{
std::size_t hidden_size_att =
parser.parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
}
}
// parse the direction attribute (forward, reverse, or bidirectional)
std::string direction{"forward"};
if(contains(info.attributes, "direction"))
{
direction = info.attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names{"tanh"};
if(contains(info.attributes, "activations"))
{
auto names = info.attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_activation_functions().count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
}
// bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(),
vec_names.end(),
vec_actv_funcs.begin(),
[&](const auto& fn) { return map_activation_functions().at(fn); });
// parse the clip attribute if present
float clip = 0.0;
if(contains(info.attributes, "clip"))
{
clip = parser.parse_value(info.attributes.at("clip")).at<float>();
}
// if the number of arguments is less than 6, append
// undefined operator to have 6 arguments
if(args.size() < 6)
{
auto ins = info.add_instruction(make_op("undefined"));
args.insert(args.end(), (6 - args.size()), ins);
}
// first output for the concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("rnn",
{{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip}}),
args);
// second output for the last hidden state
auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states);
return {hidden_states, last_output};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_selu : op_parser<parse_selu>
{
std::vector<op_desc> operators() const { return {{"Selu"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto type = args[0]->get_shape().type();
auto lens = args[0]->get_shape().lens();
float alpha = 1.67326f;
if(contains(info.attributes, "alpha"))
{
alpha = info.attributes.at("alpha").f();
}
float gamma = 1.0507f;
if(contains(info.attributes, "gamma"))
{
gamma = info.attributes.at("gamma").f();
}
auto l_alpha = info.add_literal({{type, {1}}, {alpha}});
auto l_gamma = info.add_literal({{type, {1}}, {gamma / 2.0f}});
if(lens != std::vector<std::size_t>{1})
{
l_alpha =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
l_gamma =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
}
auto sign_x = info.add_instruction(make_op("sign"), args[0]);
auto exp_x = info.add_instruction(make_op("exp"), args[0]);
auto alpha_ex = info.add_instruction(make_op("mul"), l_alpha, exp_x);
auto aex_alpha = info.add_instruction(make_op("sub"), alpha_ex, l_alpha);
auto ins1 = info.add_instruction(make_op("add"), aex_alpha, args[0]);
auto ins2 = info.add_instruction(make_op("sub"), aex_alpha, args[0]);
auto sign2 = info.add_instruction(make_op("mul"), sign_x, ins2);
auto ins_sub = info.add_instruction(make_op("sub"), ins1, sign2);
return info.add_instruction(make_op("mul"), ins_sub, l_gamma);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
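The sign trick above encodes the usual SELU definition, gamma * x for x > 0 and gamma * alpha * (exp(x) - 1) otherwise: with a = alpha*exp(x) - alpha, the expression (a + x) - sign(x)*(a - x) collapses to 2x for x > 0 and to 2a for x < 0, which the gamma/2 literal then rescales. A standalone numeric check of that identity:

#include <cmath>
#include <iostream>

int main()
{
    float alpha = 1.67326f, gamma = 1.0507f;
    for(float x : {-2.0f, -0.5f, 0.5f, 2.0f})
    {
        // reference SELU definition
        float ref = gamma * (x > 0 ? x : alpha * (std::exp(x) - 1.0f));
        // sign-based decomposition used by the parser
        float a    = alpha * std::exp(x) - alpha;
        float sign = (x > 0) - (x < 0);
        float selu = (gamma / 2.0f) * ((a + x) - sign * (a - x));
        std::cout << x << ": " << ref << " vs " << selu << '\n'; // values match
    }
}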
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Use a literal instruction to replace Shape, since the output of the
// shape operator is a literal in migraphx
struct parse_shape : op_parser<parse_shape>
{
std::vector<op_desc> operators() const { return {{"Shape"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
return info.add_literal(migraphx::literal{s, vec_shape});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_slice : op_parser<parse_slice>
{
std::vector<op_desc> operators() const { return {{"Slice"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
op::slice op;
// slice can have up to 5 inputs; we first check the 5th one
// to decide whether MIGraphX can handle this slice
if(args.size() == 5)
{
migraphx::argument step_arg = args.back()->eval();
check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
std::vector<int> steps;
step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
{
MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
}
}
if(args.size() >= 4)
{
migraphx::argument axes_arg = args.at(3)->eval();
check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
}
else if(contains(info.attributes, "axes"))
{
literal s = parser.parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
}
if(args.size() >= 3)
{
migraphx::argument end_arg = args.at(2)->eval();
check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
}
else if(contains(info.attributes, "ends"))
{
literal s = parser.parse_value(info.attributes.at("ends"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
}
if(args.size() >= 2)
{
migraphx::argument start_arg = args.at(1)->eval();
check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
}
else if(contains(info.attributes, "starts"))
{
literal s = parser.parse_value(info.attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
if(op.axes.empty())
{
std::vector<int64_t> axes(args[0]->get_shape().lens().size());
std::iota(axes.begin(), axes.end(), int64_t{0});
op.axes = axes;
}
return info.add_instruction(op, args[0]);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_split : op_parser<parse_split>
{
std::vector<op_desc> operators() const { return {{"Split"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
int64_t axis = 0;
if(contains(info.attributes, "axis"))
{
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
}
auto lens = args[0]->get_shape().lens();
int64_t n_rank = static_cast<int64_t>(lens.size());
if((axis < -n_rank) || (axis >= n_rank))
{
MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
}
int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split"))
{
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
}
}
// no split attribute, input is equally divided
else
{
if((lens[tuned_axis] % info.num_outputs) != 0)
{
MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
std::to_string(info.num_outputs) + " splits!");
}
auto dl = lens[tuned_axis] / info.num_outputs;
vec_splits.resize(info.num_outputs, dl);
}
std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
{
ret_ins.push_back(info.add_instruction(
make_op("slice", {{"axes", {axis}}, {"starts", {start}}, {"ends", {start + sl}}}),
args[0]));
start += sl;
}
return ret_ins;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
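When no split attribute is given, the axis is divided evenly across num_outputs, and each output is a slice of length lens[axis] / num_outputs; for example, a length-6 axis with 3 outputs yields slices [0, 2), [2, 4), [4, 6). A standalone sketch of the start/end computation (values are illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    int64_t axis_len = 6, num_outputs = 3;
    std::vector<int64_t> vec_splits(num_outputs, axis_len / num_outputs);
    int64_t start = 0;
    for(auto sl : vec_splits)
    {
        std::cout << "[" << start << ", " << start + sl << ") "; // prints: [0, 2) [2, 4) [4, 6)
        start += sl;
    }
    std::cout << '\n';
}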
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_tile : op_parser<parse_tile>
{
std::vector<op_desc> operators() const { return {{"Tile"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported");
std::vector<std::int64_t> repeats;
arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); });
auto l0 = args[0];
for(int i = 0; i < repeats.size(); i++)
{
auto l1 = l0;
for(int j = 1; j < repeats[i]; j++)
{
l0 = info.add_instruction(make_op("concat", {{"axis", i}}), l0, l1);
}
}
return l0;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
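For example, tiling a length-2 axis with repeats = {3} concatenates the input with itself twice along axis 0, giving {1, 2, 1, 2, 1, 2}. A standalone sketch of the same repeated-concat loop:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    std::vector<int> l0 = {1, 2};
    std::vector<int64_t> repeats = {3};
    // same scheme as parse_tile: concatenate the original data repeats[i] - 1 times
    for(std::size_t i = 0; i < repeats.size(); ++i)
    {
        auto l1 = l0;
        for(int64_t j = 1; j < repeats[i]; ++j)
            l0.insert(l0.end(), l1.begin(), l1.end());
    }
    for(int v : l0)
        std::cout << v << ' '; // prints: 1 2 1 2 1 2
    std::cout << '\n';
}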
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_transpose : op_parser<parse_transpose>
{
std::vector<op_desc> operators() const { return {{"Transpose"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::vector<int64_t> perm{};
if(contains(info.attributes, "perm"))
{
auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return info.add_instruction(make_op("transpose", {{"dims", perm}}), args.front());
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_upsample : op_parser<parse_upsample>
{
std::vector<op_desc> operators() const { return {{"Upsample"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(contains(info.attributes, "mode"))
{
auto mode = info.attributes.at("mode").s();
if(mode != "nearest")
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: only nearest mode is supported!");
}
}
auto arg_scale = args[1]->eval();
check_arg_empty(arg_scale, "PARSE_UPSAMPLE: only constant scale is supported!");
std::vector<float> vec_scale;
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
auto in_s = args[0]->get_shape();
auto in_lens = in_s.lens();
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
}
std::vector<std::size_t> out_lens(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
std::vector<float> idx_scale(in_lens.size());
std::transform(
out_lens.begin(),
out_lens.end(),
in_lens.begin(),
idx_scale.begin(),
[](auto od, auto id) { return (od == id) ? 1.0f : (id - 1.0f) / (od - 1.0f); });
shape out_s{in_s.type(), out_lens};
std::vector<int> ind(out_s.elements());
// map out_idx to in_idx
shape_for_each(out_s, [&](auto idx) {
auto in_idx = idx;
std::transform(idx.begin(),
idx.end(),
idx_scale.begin(),
in_idx.begin(),
// nearest mode
[](auto index, auto scale) {
return static_cast<std::size_t>(std::round(index * scale));
});
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
});
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
shape ind_s{shape::int32_type, out_lens};
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = info.add_literal(literal(ind_s, ind));
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_variadic_op : op_parser<parse_variadic_op>
{
std::vector<op_desc> operators() const
{
return {{"Sum", "add"}, {"Max", "max"}, {"Min", "min"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser&,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
return std::accumulate(std::next(args.begin()),
args.end(),
args.front(),
[&](instruction_ref a, instruction_ref b) {
return info.add_broadcastable_binary_op(opd.op_name, a, b);
});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_where : op_parser<parse_where>
{
std::vector<op_desc> operators() const { return {{"Where"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto cond =
info.add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
auto lens = compute_broadcasted_lens(cond->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(cond->get_shape().lens() != lens)
{
cond = info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond);
}
if(args[1]->get_shape().lens() != lens)
{
args[1] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]);
}
// compute index
auto elem_num = args[1]->get_shape().elements();
// concatenation of input data
auto concat_data = info.add_instruction(make_op("concat", {{"axis", 0}}), args[2], args[1]);
std::vector<int64_t> dims = {static_cast<int64_t>(2 * elem_num)};
auto rsp_data = info.add_instruction(make_op("reshape", {{"dims", dims}}), concat_data);
std::vector<int> ind(elem_num);
std::iota(ind.begin(), ind.end(), 0);
shape ind_s{shape::int32_type, lens};
auto l_ind = info.add_literal(literal(ind_s, ind));
std::vector<int> offset(elem_num, elem_num);
auto l_offset = info.add_literal(literal({shape::int32_type, lens}, offset));
auto ins_offset = info.add_instruction(make_op("mul"), l_offset, cond);
auto ins_ind = info.add_instruction(make_op("add"), ins_offset, l_ind);
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
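The Where parser gathers from the concatenation of the two branches: else-values occupy the first elem_num slots and then-values the next elem_num, so the gather index for each element is iota + cond * elem_num. For a 4-element tensor with cond = {1, 0, 0, 1}, the indices are {4, 1, 2, 7}, selecting from the then-branch where cond is 1 and from the else-branch otherwise. A standalone check with illustrative values:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<int> else_vals = {10, 20, 30, 40}; // args[2]
    std::vector<int> then_vals = {1, 2, 3, 4};     // args[1]
    std::vector<int> cond      = {1, 0, 0, 1};     // args[0], as int32
    // concat(else, then), then gather with index = iota + cond * elem_num
    std::vector<int> data(else_vals);
    data.insert(data.end(), then_vals.begin(), then_vals.end());
    std::size_t elem_num = else_vals.size();
    for(std::size_t i = 0; i < elem_num; ++i)
    {
        std::size_t ind = i + cond[i] * elem_num;
        std::cout << data[ind] << ' '; // prints: 1 20 30 4
    }
    std::cout << '\n';
}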
#include <vector>
#include <algorithm>
#include <cmath>
#include <numeric>
template <typename T>
std::vector<T> softmax(const std::vector<T>& p)
{
size_t n = p.size();
std::vector<T> result(n);
std::transform(p.begin(), p.end(), result.begin(), [](auto x) { return std::exp(x); });
T s = std::accumulate(result.begin(), result.end(), T{0});
std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; });
return result;
}
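A small usage check for the helper above (assuming the softmax template above is in the same translation unit; values are approximate): softmax({1, 2, 3}) is roughly {0.090, 0.245, 0.665}.

#include <iostream>

int main()
{
    std::vector<double> p = {1.0, 2.0, 3.0};
    for(double v : softmax(p))
        std::cout << v << ' '; // ~0.0900 0.2447 0.6652
    std::cout << '\n';
}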