Unverified Commit 002eb4e2 authored by Paul Fultz II, committed by GitHub

Add C++ ability to construct operators by name (#616)



* Add make_op function

* Formatting

* Add more values

* Formatting

* Remove templates from parse_conv functions

* Formatting

* Remove mat_mul template

* Formatting

* Reduce header includes

* Fix compiling for gpu

* Formatting

* Use make_op in lowering

* Formatting

* Sort lines

* Formatting

* Add more tests

* Formatting

* Fix tidy error

* Formatting

* Add const refs

* Add explicit this

* Add more const refs

* Sort the program

* Remove commented out code

* Formatting

* Infer gpu prefix

* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 56b3bf58
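
In short, this change adds a small factory, make_op, that builds an operator from the registry by name and can optionally override its reflected fields from a value. A minimal usage sketch follows; it mirrors the tests added at the bottom of this diff, and the main wrapper is only illustrative:

#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>

int main()
{
    // Look up an operator by its registered name.
    migraphx::operation relu = migraphx::make_op("relu");

    // Look up an operator and override selected fields from a value.
    migraphx::operation conv = migraphx::make_op(
        "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}});

    // Fields not mentioned keep the operator's defaults; unknown keys throw.
    return (relu.name() == "relu" and conv.name() == "convolution") ? 0 : 1;
}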
......@@ -20,6 +20,7 @@ add_library(migraphx
env.cpp
generate.cpp
instruction.cpp
make_op.cpp
msgpack.cpp
program.cpp
quantization.cpp
......
#ifndef MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -23,7 +23,7 @@ struct quant_dot
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(as_number(self.alpha), "alpha"), f(as_number(self.beta), "beta"));
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "quant_dot"; }
......
......@@ -47,7 +47,7 @@ value to_value_impl(rank<1>, const std::pair<T, U>& x)
template <class T>
auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
{
value result;
value result = value::array{};
for(auto&& y : x)
{
auto e = to_value(y);
......@@ -59,7 +59,7 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
value to_value_impl(rank<3>, const T& x)
{
value result;
value result = value::object{};
reflect_each(x, [&](auto&& y, std::string name) { result.emplace(name, to_value(y)); });
return result;
}
......
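
The two one-line changes above are what make empty inputs round-trip correctly: without the explicit array{}/object{} initialization, an empty container or an empty reflectable struct would serialize to a null value. A small sketch of the intended behavior, mirroring the serialize_empty_array test added below (the header name is assumed to be the one declaring to_value):

#include <migraphx/serialize.hpp>
#include <migraphx/value.hpp>
#include <cstddef>
#include <vector>

// An empty vector should still become an (empty) array value,
// so elements can be appended to it afterwards.
void empty_vector_to_value()
{
    std::vector<std::size_t> ints = {};
    migraphx::value v = migraphx::to_value(ints);
    // v.is_array() and v.empty() hold here; push_back then grows the array.
    v.push_back(1);
}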
......@@ -73,6 +73,13 @@ struct value_converter<std::pair<T, U>>
};
namespace detail {
template <class To, class Key, class From>
auto try_convert_value_impl(rank<2>, const std::pair<Key, From>& x)
-> decltype(value_converter<To>::apply(x.second))
{
return value_converter<To>::apply(x.second);
}
template <class To, class From>
auto try_convert_value_impl(rank<1>, const From& x) -> decltype(value_converter<To>::apply(x))
{
......@@ -89,7 +96,7 @@ To try_convert_value_impl(rank<0>, const From& x)
template <class To, class From>
To try_convert_value(const From& x)
{
return detail::try_convert_value_impl<To>(rank<1>{}, x);
return detail::try_convert_value_impl<To>(rank<2>{}, x);
}
struct value
......@@ -159,6 +166,26 @@ struct value
using is_pickable =
std::integral_constant<bool, (std::is_arithmetic<T>{} and not std::is_pointer<T>{})>;
template <class T>
using range_value = std::decay_t<decltype(std::declval<T>().end(), *std::declval<T>().begin())>;
template <class T>
using is_generic_range =
std::integral_constant<bool,
(std::is_convertible<range_value<T>, value>{} and
not std::is_convertible<T, array>{} and
not std::is_convertible<T, object>{})>;
template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
value(const T& r) : value(from_values(r))
{
}
template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
value(const std::string& pkey, const T& r) : value(pkey, from_values(r))
{
}
template <class T, MIGRAPHX_REQUIRES(is_pickable<T>{})>
value(T i) : value(pick<T>{i})
{
......@@ -176,6 +203,11 @@ struct value
{
return *this = pick<T>{rhs}; // NOLINT
}
template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
value& operator=(T rhs)
{
return *this = from_values(rhs); // NOLINT
}
value& operator=(std::nullptr_t);
......@@ -214,6 +246,10 @@ struct value
const value& operator[](std::size_t i) const;
value& operator[](const std::string& pkey);
void clear();
void resize(std::size_t n);
void resize(std::size_t n, const value& v);
std::pair<value*, bool> insert(const value& v);
value* insert(const value* pos, const value& v);
......@@ -294,6 +330,14 @@ struct value
void debug_print(bool show_type = false) const;
private:
template <class T>
std::vector<value> from_values(const T& r)
{
std::vector<value> v;
std::transform(
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v;
}
type_t get_type() const;
std::shared_ptr<value_base_impl> x;
std::string key;
......
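
The generic-range constructor and assignment added to value above let a value be built directly from any range whose elements convert to value, for example a std::vector<int>, without spelling out value::array. A minimal sketch matching the value_*_from_vector tests added below:

#include <migraphx/value.hpp>
#include <vector>

void value_from_range()
{
    std::vector<int> v = {1, 2, 3};

    migraphx::value a = v;               // copy-initialization via the new range constructor
    migraphx::value b(v);                // direct construction
    migraphx::value c;
    c = v;                               // assignment through the new operator=
    migraphx::value nested = {{"a", v}}; // also usable as a keyed entry of an object

    // Each of a, b, and c is an array; to_vector<int>() recovers {1, 2, 3}.
}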
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
{
auto op = load_op(name);
// Merge values
value w = op.to_value();
for(auto&& x : v)
{
w.at(x.get_key()) = x.without_key();
}
op.from_value(w);
return op;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
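
make_op(name, v) starts from the operator's default values via op.to_value() and overwrites only the keys present in v; since the merge goes through value::at, a key the operator does not expose throws instead of being silently ignored (this is what the make_op_invalid_key test below exercises). A short sketch:

#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>

void make_op_merge_example()
{
    // Only "padding" is overridden; the remaining convolution fields keep their defaults.
    auto conv = migraphx::make_op("convolution", {{"padding", {1, 1}}});

    // "paddings" is not a field of convolution, so this would throw:
    // migraphx::make_op("convolution", {{"paddings", {1, 1}}});
    (void)conv;
}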
......@@ -11,7 +11,7 @@
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
......@@ -20,6 +20,31 @@
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -50,63 +75,67 @@ struct onnx_parser
onnx_parser()
{
// sort onnx operator alphabetically through name
add_generic_op("Abs", op::abs{});
add_generic_op("Acos", op::acos{});
add_generic_op("Acosh", op::acosh{});
add_generic_op("Asin", op::asin{});
add_generic_op("Asinh", op::asinh{});
add_generic_op("Atan", op::atan{});
add_generic_op("Atanh", op::atanh{});
add_generic_op("Ceil", op::ceil{});
add_generic_op("Cos", op::cos{});
add_generic_op("Cosh", op::cosh{});
add_generic_op("Erf", op::erf{});
add_generic_op("Exp", op::exp{});
add_generic_op("Dropout", op::identity{});
add_generic_op("Floor", op::floor{});
add_generic_op("Identity", op::identity{});
add_generic_op("Log", op::log{});
add_generic_op("Neg", op::neg{});
add_generic_op("Reciprocal", op::recip{});
add_generic_op("Relu", op::relu{});
add_generic_op("Round", op::round{});
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Sign", op::sign{});
add_generic_op("Sin", op::sin{});
add_generic_op("Sinh", op::sinh{});
add_generic_op("Sqrt", op::sqrt{});
add_generic_op("Tan", op::tan{});
add_generic_op("Tanh", op::tanh{});
add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{});
add_binary_op("Pow", op::pow{});
add_binary_op("PRelu", op::prelu{});
add_binary_op("Sub", op::sub{});
add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_generic_op("Abs", "abs");
add_generic_op("Acos", "acos");
add_generic_op("Acosh", "acosh");
add_generic_op("Asin", "asin");
add_generic_op("Asinh", "asinh");
add_generic_op("Atan", "atan");
add_generic_op("Atanh", "atanh");
add_generic_op("Ceil", "ceil");
add_generic_op("Concat", "concat");
add_generic_op("Cos", "cos");
add_generic_op("Cosh", "cosh");
add_generic_op("Dropout", "identity");
add_generic_op("Erf", "erf");
add_generic_op("Exp", "exp");
add_generic_op("Flatten", "flatten");
add_generic_op("Floor", "floor");
add_generic_op("Gather", "gather", true);
add_generic_op("Identity", "identity");
add_generic_op("Log", "log");
add_generic_op("LogSoftmax", "logsoftmax");
add_generic_op("Neg", "neg");
add_generic_op("Reciprocal", "recip");
add_generic_op("Relu", "relu");
add_generic_op("Round", "round");
add_generic_op("Sigmoid", "sigmoid");
add_generic_op("Sign", "sign");
add_generic_op("Sin", "sin");
add_generic_op("Sinh", "sinh");
add_generic_op("Softmax", "softmax");
add_generic_op("Sqrt", "sqrt");
add_generic_op("Squeeze", "squeeze", true);
add_generic_op("Tan", "tan");
add_generic_op("Tanh", "tanh");
add_generic_op("Unsqueeze", "unsqueeze", true);
add_binary_op("Add", "add");
add_binary_op("Div", "div");
add_binary_op("Mul", "mul");
add_binary_op("Pow", "pow");
add_binary_op("PRelu", "prelu");
add_binary_op("Sub", "sub");
add_variadic_op("Sum", "add");
add_variadic_op("Max", "max");
add_variadic_op("Min", "min");
add_mem_op("ATen", &onnx_parser::parse_aten);
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
add_mem_op("ArgMax", "argmax", &onnx_parser::parse_arg_op);
add_mem_op("ArgMin", "argmin", &onnx_parser::parse_arg_op);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
add_mem_op("Conv", "convolution", &onnx_parser::parse_conv);
add_mem_op("ConvInteger", "quant_convolution", &onnx_parser::parse_conv);
add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("GatherElements", &onnx_parser::parse_gather_elements);
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
......@@ -115,11 +144,10 @@ struct onnx_parser
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
add_mem_op("MatMul", "dot", &onnx_parser::parse_matmul);
add_mem_op("MatMulInteger", "quant_dot", &onnx_parser::parse_matmul);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("NonZero", &onnx_parser::parse_nonzero);
add_mem_op("OneHot", &onnx_parser::parse_onehot);
......@@ -129,22 +157,19 @@ struct onnx_parser
add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceMax", "reduce_max", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceMean", "reduce_mean", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceMin", "reduce_min", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceProd", "reduce_prod", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceSum", "reduce_sum", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("Split", &onnx_parser::parse_split);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Tile", &onnx_parser::parse_tile);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
// init the activation function map
init_actv_func();
......@@ -153,11 +178,39 @@ struct onnx_parser
void init_actv_func()
{
// Support name format of all lower case or the first letter capital
map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
map_actv_funcs.insert(std::make_pair("tanh", make_op("tanh")));
map_actv_funcs.insert(std::make_pair("relu", make_op("relu")));
map_actv_funcs.insert(std::make_pair("sigmoid", make_op("sigmoid")));
map_actv_funcs.insert(std::make_pair("leakyrelu", make_op("leaky_relu")));
map_actv_funcs.insert(std::make_pair("elu", make_op("elu")));
}
static operation load(const std::string& name, const node_info& info)
{
auto op = make_op(name);
auto v = op.to_value();
for(auto&& x : v)
{
if(info.attributes.count(x.get_key()) == 0)
continue;
literal s = parse_value(info.attributes.at(x.get_key()));
if(x.is_array())
{
std::vector<value> values;
s.visit([&](auto y) {
std::transform(y.begin(), y.end(), std::back_inserter(values), [](auto z) {
return value(z);
});
});
x = values;
}
else
{
s.visit([&](auto y) { x = y.front(); });
}
}
op.from_value(v);
return op;
}
template <class F>
......@@ -176,17 +229,24 @@ struct onnx_parser
}
template <class F>
void add_mem_op(std::string name, F f)
void add_mem_op(const std::string& name, F f)
{
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template <class T>
void add_binary_op(std::string name, T x)
template <class F>
void add_mem_op(const std::string& onnx_name, const std::string& op_name, F f)
{
add_op(onnx_name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, onnx_name, op_name, std::forward<decltype(xs)>(xs)...);
});
}
void add_binary_op(const std::string& onnx_name, const std::string& op_name)
{
add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
add_op(onnx_name, [this, op_name](node_info info, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
......@@ -197,13 +257,13 @@ struct onnx_parser
uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
args[1]);
return prog.add_instruction(x, args[0], l);
return prog.add_instruction(make_op(op_name), args[0], l);
}
return prog.add_instruction(x, args);
return prog.add_instruction(make_op(op_name), args);
}
else
{
return add_broadcastable_binary_op(args[0], args[1], x);
return add_broadcastable_binary_op(args[0], args[1], op_name);
}
});
}
......@@ -254,11 +314,11 @@ struct onnx_parser
return ins;
}
return prog.add_instruction(op::contiguous{}, ins);
return prog.add_instruction(make_op("contiguous"), ins);
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
instruction_ref
add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, const std::string& name)
{
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
......@@ -275,31 +335,40 @@ struct onnx_parser
if(arg1->get_shape().lens() != out_lens)
l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
return prog.add_instruction(x, l0, l1);
return prog.add_instruction(make_op(name), l0, l1);
}
else
{
return prog.add_instruction(x, {arg0, arg1});
return prog.add_instruction(make_op(name), {arg0, arg1});
}
}
template <class T>
void add_generic_op(std::string name, T x)
void add_generic_op(const std::string& onnx_name,
const std::string& op_name,
bool contiguous = false)
{
add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
add_op(
onnx_name,
[this, op_name, contiguous](const node_info& info, std::vector<instruction_ref> args) {
auto op = load(op_name, info);
if(contiguous)
{
std::transform(args.begin(), args.end(), args.begin(), [&](auto arg) {
return this->make_contiguous(arg);
});
}
return prog.add_instruction(op, args);
});
}
template <class T>
void add_variadic_op(std::string name, T x)
void add_variadic_op(const std::string& onnx_name, const std::string& op_name)
{
add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
add_op(onnx_name, [this, op_name](const node_info&, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()),
args.end(),
args.front(),
[this, x](instruction_ref a, instruction_ref b) {
return add_broadcastable_binary_op(a, b, x);
[this, op_name](instruction_ref a, instruction_ref b) {
return add_broadcastable_binary_op(a, b, op_name);
});
});
}
......@@ -318,7 +387,7 @@ struct onnx_parser
{
auto bias_bcast =
prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
return prog.add_instruction(make_op("add"), curr_ins, bias_bcast);
}
return curr_ins;
}
......@@ -338,10 +407,9 @@ struct onnx_parser
return false;
}
template <class Op>
void check_asym_padding(instruction_ref& ins,
const std::vector<int64_t>& padding,
Op& op,
value& v,
int count_include_pad = 0,
float pad_val = 0)
{
......@@ -360,7 +428,7 @@ struct onnx_parser
}
else
{
op.padding = std::vector<size_t>(left_pad_it, right_pad_it);
v["padding"] = std::vector<size_t>(left_pad_it, right_pad_it);
}
}
......@@ -404,29 +472,17 @@ struct onnx_parser
max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);
if(min_used and max_used)
return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
return prog.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
if(min_used)
return prog.add_instruction(op::max{}, args[0], min_arg);
return prog.add_instruction(op::identity{}, args[0]);
}
template <class Op>
instruction_ref
parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int64_t axis = 1;
if(contains(info.attributes, "axis"))
{
axis = parse_value(info.attributes.at("axis")).at<int>();
}
return prog.add_instruction(make_op("max"), args[0], min_arg);
return prog.add_instruction(Op{axis}, std::move(args));
return prog.add_instruction(make_op("identity"), args[0]);
}
template <class Op>
instruction_ref
parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
instruction_ref parse_arg_op(const std::string&,
const std::string& op_name,
node_info info,
std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(info.attributes, "axis"))
......@@ -442,12 +498,12 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = prog.add_instruction(Op{axis}, std::move(args));
auto ins = prog.add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(Op{axis}, std::move(args));
return prog.add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
}
}
......@@ -536,29 +592,27 @@ struct onnx_parser
}
}
template <class Op>
void recalc_conv_attributes(Op& op, size_t kdims)
void recalc_conv_attributes(value& v, size_t kdims)
{
if(op.padding.size() != kdims)
if(v["padding"].size() != kdims)
{
op.padding.resize(kdims);
std::fill_n(op.padding.begin(), kdims, 0);
v["padding"].resize(kdims);
std::fill_n(v["padding"].begin(), kdims, 0);
}
if(op.stride.size() != kdims)
if(v["stride"].size() != kdims)
{
op.stride.resize(kdims);
std::fill_n(op.stride.begin(), kdims, 1);
v["stride"].resize(kdims);
std::fill_n(v["stride"].begin(), kdims, 1);
}
if(op.dilation.size() != kdims)
if(v["dilation"].size() != kdims)
{
op.dilation.resize(kdims);
std::fill_n(op.dilation.begin(), kdims, 1);
v["dilation"].resize(kdims);
std::fill_n(v["dilation"].begin(), kdims, 1);
}
}
template <class Op>
static void cal_auto_padding_size(node_info info,
Op& op,
value& v,
const std::vector<std::size_t>& k_lens,
const std::vector<std::size_t>& dilation,
const std::vector<std::size_t>& in_lens,
......@@ -575,7 +629,7 @@ struct onnx_parser
auto auto_pad = info.attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
v["padding_mode"] = to_value(op::padding_mode_t::same);
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
paddings.resize(2 * kdims);
......@@ -584,7 +638,7 @@ struct onnx_parser
calculate_padding(i,
paddings,
in_lens[i + 2],
op.stride[i],
v["stride"][i].to<int64_t>(),
dilation[i],
k_lens[i],
is_same_upper);
......@@ -606,11 +660,13 @@ struct onnx_parser
}
}
template <class Op>
instruction_ref
parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
instruction_ref parse_conv(const std::string&,
const std::string& op_name,
node_info info,
std::vector<instruction_ref> args)
{
Op op;
auto op = make_op(op_name);
auto values = op.to_value();
auto l0 = args[0];
auto weights = args[1];
auto in_lens = l0->get_shape().lens();
......@@ -622,21 +678,22 @@ struct onnx_parser
if(contains(info.attributes, "strides"))
{
op.stride.clear();
copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
check_attr_sizes(kdims, op.stride.size(), "PARSE_CONV: inconsistent strides");
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_CONV: inconsistent strides");
}
if(contains(info.attributes, "dilations"))
{
op.dilation.clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(op.dilation));
check_attr_sizes(kdims, op.dilation.size(), "PARSE_CONV: inconsistent dilations");
values["dilation"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
check_attr_sizes(
kdims, values["dilation"].size(), "PARSE_CONV: inconsistent dilations");
}
std::vector<int64_t> padding;
if(contains(info.attributes, "pads"))
{
op.padding.clear();
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
}
......@@ -646,17 +703,23 @@ struct onnx_parser
auto weight_lens = weights->get_shape().lens();
std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
cal_auto_padding_size(info, op, k_lens, op.dilation, in_lens, padding);
cal_auto_padding_size(info,
values,
k_lens,
values["dilation"].to_vector<std::size_t>(),
in_lens,
padding);
}
check_asym_padding(l0, padding, op);
check_asym_padding(l0, padding, values);
if(contains(info.attributes, "group"))
{
op.group = parse_value(info.attributes.at("group")).at<int>();
values["group"] = parse_value(info.attributes.at("group")).at<int>();
}
recalc_conv_attributes(op, kdims);
recalc_conv_attributes(values, kdims);
op.from_value(values);
auto l1 = prog.add_instruction(op, l0, args[1]);
return add_bias(args, l1, 1);
}
......@@ -664,7 +727,9 @@ struct onnx_parser
instruction_ref
parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::deconvolution op;
operation op = make_op("deconvolution");
value values = op.to_value();
// op::deconvolution op;
auto l0 = args[0];
std::vector<std::int64_t> padding;
bool asym_padding = false;
......@@ -685,25 +750,26 @@ struct onnx_parser
{
size_t pad_ndims = padding.size() / 2;
check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
op.padding.clear();
values["padding"].clear();
std::transform(padding.begin(),
padding.begin() + pad_ndims,
std::back_inserter(op.padding),
std::back_inserter(values["padding"]),
[](auto pad_val) { return pad_val; });
}
}
if(contains(info.attributes, "strides"))
{
op.stride.clear();
copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
check_attr_sizes(kdims, op.stride.size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(
kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
}
if(contains(info.attributes, "dilations"))
{
op.dilation.clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(op.dilation));
values["dilation"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
check_attr_sizes(
kdims, op.dilation.size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
}
if(contains(info.attributes, "auto_pad"))
{
......@@ -716,17 +782,18 @@ struct onnx_parser
if(s.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
values["padding_mode"] = to_value(op::padding_mode_t::same);
}
}
if(contains(info.attributes, "group"))
{
op.group = parse_value(info.attributes.at("group")).at<int>();
values["group"] = parse_value(info.attributes.at("group")).at<int>();
}
recalc_conv_attributes(op, kdims);
recalc_conv_attributes(values, kdims);
op.from_value(values);
auto l1 = prog.add_instruction(op, l0, args[1]);
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
......@@ -799,13 +866,13 @@ struct onnx_parser
}
}
static void tune_padding_size(const op::pooling& op,
static void tune_padding_size(const value& v,
std::vector<int64_t>& padding,
int count_include_pad,
std::vector<int64_t>& s_start)
{
// maxpooling or count_include_pad is 1, no change is required.
if(op.mode == "max" or count_include_pad == 1)
if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1)
{
return;
}
......@@ -821,22 +888,25 @@ struct onnx_parser
s_start.resize(n_dims);
for(std::size_t i = 0; i < n_dims; ++i)
{
tune_padding_to_symmetric(padding[i], padding[i + n_dims], op.stride[i], s_start[i]);
tune_padding_to_symmetric(
padding[i], padding[i + n_dims], v.at("stride")[i].to<int64_t>(), s_start[i]);
}
}
instruction_ref
parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
{
op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
std::string mode = ends_with(name, "MaxPool") ? "max" : "average";
operation op = make_op("pooling", {{"mode", mode}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
if(starts_with(name, "Global"))
{
op.lengths = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
}
// does not support ceil_mode
......@@ -858,25 +928,26 @@ struct onnx_parser
if(contains(info.attributes, "strides"))
{
op.stride.clear();
copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
check_attr_sizes(kdims, op.stride.size(), "PARSE_POOLING: inconsistent strides");
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
op.lengths.clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(op.lengths));
check_attr_sizes(kdims, op.lengths.size(), "PARSE_POOLING: inconsistent lengths");
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
std::vector<int64_t> paddings;
float pad_val = ((op.mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
op.padding.clear();
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
......@@ -884,9 +955,14 @@ struct onnx_parser
if(contains(info.attributes, "auto_pad"))
{
op.padding.clear();
values["padding"].clear();
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size(info, op, op.lengths, {1, 1}, in_lens, paddings);
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
{1, 1},
in_lens,
paddings);
}
if(paddings.size() != 2 * kdims)
......@@ -895,23 +971,23 @@ struct onnx_parser
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(op.padding.size() != kdims)
if(values["padding"].size() != kdims)
{
op.padding.resize(kdims);
std::fill_n(op.padding.begin(), kdims, 0);
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(op.stride.size() != kdims)
if(values["stride"].size() != kdims)
{
op.stride.resize(kdims);
std::fill_n(op.stride.begin(), kdims, 1);
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding(paddings.begin(), paddings.end());
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(op, paddings, count_include_pad, slice_start);
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(!slice_start.empty())
{
......@@ -920,7 +996,7 @@ struct onnx_parser
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
auto out_lens = op.compute_shape({padded_shape}).lens();
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
......@@ -931,16 +1007,17 @@ struct onnx_parser
[](auto i, auto j) { return i + j; });
}
check_asym_padding(l0, paddings, op, count_include_pad, pad_val);
check_asym_padding(l0, paddings, values, count_include_pad, pad_val);
in_lens = l0->get_shape().lens();
for(size_t i = 0; i < kdims; i++)
{
if(op.lengths[i] > in_lens[i + 2] + 2 * op.padding[i])
if(values["lengths"][i].to<int64_t>() >
in_lens[i + 2] + 2 * values["padding"][i].to<int64_t>())
{
MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
}
}
op.from_value(values);
auto l1 = prog.add_instruction(op, l0);
if(!slice_start.empty())
{
......@@ -971,62 +1048,6 @@ struct onnx_parser
return prog.add_instruction(op, make_contiguous(args[0]));
}
instruction_ref
parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int64_t axis = 1;
if(contains(info.attributes, "axis"))
{
axis = parse_value(info.attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::flatten{axis}, args[0]);
}
instruction_ref
parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::squeeze op;
literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, make_contiguous(args[0]));
}
instruction_ref
parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::unsqueeze op;
literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, make_contiguous(args[0]));
}
instruction_ref
parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
{
// change to hande axis to be negative values
if(!contains(info.attributes, "axis"))
{
MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
}
int axis = parse_value(info.attributes.at("axis")).at<int>();
op::concat op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int axis = 0;
if(contains(info.attributes, "axis"))
{
axis = parse_value(info.attributes.at("axis")).at<int>();
}
op::gather op{axis};
return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
}
instruction_ref
parse_gather_elements(const std::string&, node_info info, std::vector<instruction_ref> args)
{
......@@ -1077,9 +1098,9 @@ struct onnx_parser
auto l_dim_idx = prog.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = prog.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = prog.add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
auto dim_diff = prog.add_instruction(op::sub{}, arg_ind, l_dim_idx);
auto delta = prog.add_instruction(op::mul{}, dim_diff, l_stride);
auto ind = prog.add_instruction(op::add{}, l_shape_idx, delta);
auto dim_diff = prog.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = prog.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = prog.add_instruction(make_op("add"), l_shape_idx, delta);
op::gather op{0};
return prog.add_instruction(op, arg_data, ind);
......@@ -1214,16 +1235,18 @@ struct onnx_parser
{
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
return prog.add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
}
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
return prog.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
}
template <class Op>
instruction_ref
parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
instruction_ref parse_matmul(const std::string&,
const std::string& op_name,
const node_info&,
std::vector<instruction_ref> args)
{
auto l0 = args[0];
auto l1 = args[1];
......@@ -1270,7 +1293,8 @@ struct onnx_parser
}
}
auto dot_res = prog.add_instruction(Op{1, 0}, bl0, bl1);
auto dot_res =
prog.add_instruction(make_op(op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
......@@ -1326,22 +1350,22 @@ struct onnx_parser
auto bias = args[2];
auto dims = x->get_shape().lens();
auto mean = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
auto mean = prog.add_instruction(make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast = prog.add_instruction(op::multibroadcast{dims}, mean);
auto l0 = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
auto variance = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
auto l1 = prog.add_instruction(op::sub{}, x, mean_bcast);
auto l0 = prog.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = prog.add_instruction(make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
auto l1 = prog.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = prog.add_literal(epsilon);
auto epsilon_bcast = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = prog.add_instruction(op::multibroadcast{dims}, variance);
auto l2 = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
auto l3 = prog.add_instruction(op::rsqrt{}, l2);
auto l4 = prog.add_instruction(op::mul{}, l1, l3);
auto l2 = prog.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = prog.add_instruction(make_op("rsqrt"), l2);
auto l4 = prog.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = prog.add_instruction(op::broadcast{1, dims}, scale);
;
auto bias_bcast = prog.add_instruction(op::broadcast{1, dims}, bias);
auto l5 = prog.add_instruction(op::mul{}, l4, scale_bcast);
return prog.add_instruction(op::add{}, l5, bias_bcast);
auto l5 = prog.add_instruction(make_op("mul"), l4, scale_bcast);
return prog.add_instruction(make_op("add"), l5, bias_bcast);
}
instruction_ref
......@@ -1352,7 +1376,7 @@ struct onnx_parser
{
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
op::leaky_relu op{alpha};
auto op = make_op("leaky_relu", {{"alpha", alpha}});
return prog.add_instruction(op, args.front());
}
......@@ -1363,7 +1387,7 @@ struct onnx_parser
{
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
op::elu op{alpha};
auto op = make_op("elu", {{"alpha", alpha}});
return prog.add_instruction(op, args.front());
}
......@@ -1408,9 +1432,10 @@ struct onnx_parser
auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto img_scaled =
prog.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return prog.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
instruction_ref
......@@ -1447,7 +1472,7 @@ struct onnx_parser
// check if padding is actually being done (at least one value is nonzero)
if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
{
return prog.add_instruction(migraphx::op::identity{}, args.front());
return prog.add_instruction(make_op("identity"), args.front());
}
if(contains(info.attributes, "mode"))
......@@ -2039,9 +2064,10 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output};
}
template <class T>
instruction_ref
parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
instruction_ref parse_reduce_oper(const std::string&,
const std::string& op_name,
node_info info,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
......@@ -2063,11 +2089,11 @@ struct onnx_parser
if(keep_dims == 1)
{
return prog.add_instruction(T{axes}, std::move(args));
return prog.add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
}
else
{
auto ins = prog.add_instruction(T{axes}, std::move(args));
auto ins = prog.add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
......@@ -2075,38 +2101,38 @@ struct onnx_parser
instruction_ref
parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
auto abs_ins = prog.add_instruction(make_op("abs"), args[0]);
return parse_reduce_oper({}, "reduce_sum", std::move(info), {abs_ins});
}
instruction_ref
parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
return prog.add_instruction(op::sqrt{}, sum_ins);
auto square_ins = prog.add_instruction(make_op("mul"), args[0], args[0]);
auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
return prog.add_instruction(make_op("sqrt"), sum_ins);
}
instruction_ref
parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
return prog.add_instruction(op::log{}, sum_ins);
auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), std::move(args));
return prog.add_instruction(make_op("log"), sum_ins);
}
instruction_ref
parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
return prog.add_instruction(op::log{}, sum_ins);
auto exp_ins = prog.add_instruction(make_op("exp"), args[0]);
auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {exp_ins});
return prog.add_instruction(make_op("log"), sum_ins);
}
instruction_ref
parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
auto square_ins = prog.add_instruction(make_op("mul"), args[0], args[0]);
return parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
}
instruction_ref
......@@ -2214,11 +2240,11 @@ struct onnx_parser
auto off_val = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
auto on_val = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
auto diff = prog.add_instruction(op::sub{}, on_val, off_val);
auto diff = prog.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val = prog.add_instruction(op::multibroadcast{lens}, off_val);
auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
auto l_mul = prog.add_instruction(op::mul{}, tr_out, unsq_diff_val);
return prog.add_instruction(op::add{}, l_mul, unsq_off_val);
auto l_mul = prog.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return prog.add_instruction(make_op("add"), l_mul, unsq_off_val);
}
instruction_ref
......@@ -2302,9 +2328,15 @@ struct onnx_parser
auto l0 = prog.add_instruction(op::gather{}, args[0], args[1]);
switch(reduce_mode)
{
case reduce_mode_t::sum: l0 = prog.add_instruction(op::reduce_sum{{0}}, l0); break;
case reduce_mode_t::mean: l0 = prog.add_instruction(op::reduce_mean{{0}}, l0); break;
case reduce_mode_t::max: l0 = prog.add_instruction(op::reduce_max{{0}}, l0); break;
case reduce_mode_t::sum:
l0 = prog.add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0);
break;
case reduce_mode_t::mean:
l0 = prog.add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0);
break;
case reduce_mode_t::max:
l0 = prog.add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0);
break;
}
return l0;
}
......
#include <rocblas.h>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/gpu/relu.hpp>
#include <migraphx/gpu/sigmoid.hpp>
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/sub.hpp>
#include <migraphx/gpu/div.hpp>
#include <migraphx/gpu/exp.hpp>
#include <migraphx/gpu/erf.hpp>
#include <migraphx/gpu/log.hpp>
#include <migraphx/gpu/sin.hpp>
#include <migraphx/gpu/sign.hpp>
#include <migraphx/gpu/cos.hpp>
#include <migraphx/gpu/tan.hpp>
#include <migraphx/gpu/sinh.hpp>
#include <migraphx/gpu/cosh.hpp>
#include <migraphx/gpu/tanh.hpp>
#include <migraphx/gpu/asin.hpp>
#include <migraphx/gpu/acos.hpp>
#include <migraphx/gpu/atan.hpp>
#include <migraphx/gpu/asinh.hpp>
#include <migraphx/gpu/acosh.hpp>
#include <migraphx/gpu/atanh.hpp>
#include <migraphx/gpu/mul.hpp>
#include <migraphx/gpu/max.hpp>
#include <migraphx/gpu/min.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/gather.hpp>
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/round.hpp>
#include <migraphx/gpu/ceil.hpp>
#include <migraphx/gpu/floor.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_max.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/reduce_min.hpp>
#include <migraphx/gpu/reduce_prod.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/prelu.hpp>
#include <migraphx/gpu/recip.hpp>
#include <migraphx/gpu/rnn_variable_seq_lens.hpp>
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/iterator_for.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -136,61 +95,59 @@ struct miopen_apply
add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
add_miopen_extend_op<miopen_elu, op::elu>("elu", make_elu);
add_generic_op<hip_add>("add");
add_generic_op<hip_sub>("sub");
add_generic_op<hip_exp>("exp");
add_generic_op<hip_erf>("erf");
add_generic_op<hip_log>("log");
add_generic_op<hip_sin>("sin");
add_generic_op<hip_cos>("cos");
add_generic_op<hip_tan>("tan");
add_generic_op<hip_sinh>("sinh");
add_generic_op<hip_cosh>("cosh");
add_generic_op<hip_tanh>("tanh");
add_generic_op<hip_asin>("asin");
add_generic_op<hip_acos>("acos");
add_generic_op<hip_atan>("atan");
add_generic_op<hip_asinh>("asinh");
add_generic_op<hip_acosh>("acosh");
add_generic_op<hip_atanh>("atanh");
add_generic_op<hip_sqrt>("sqrt");
add_generic_op<hip_mul>("mul");
add_generic_op<hip_div>("div");
add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_round>("round");
add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff");
add_generic_op<hip_relu>("relu");
add_generic_op<hip_prelu>("prelu");
add_generic_op<hip_sign>("sign");
add_generic_op<hip_sigmoid>("sigmoid");
add_generic_op<hip_ceil>("ceil");
add_generic_op<hip_floor>("floor");
add_generic_op<hip_recip>("recip");
add_generic_op<miopen_contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<hip_softmax, op::softmax>("softmax");
add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax");
add_extend_op<hip_argmax, op::argmax>("argmax");
add_extend_op<hip_argmin, op::argmin>("argmin");
add_extend_op<hip_gather, op::gather>("gather");
add_extend_op<hip_pad, op::pad>("pad");
add_extend_op<hip_convert, op::convert>("convert");
add_extend_op<hip_clip, op::clip>("clip");
add_extend_op<hip_reduce_max, op::reduce_max>("reduce_max");
add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean");
add_extend_op<hip_reduce_min, op::reduce_min>("reduce_min");
add_extend_op<hip_reduce_prod, op::reduce_prod>("reduce_prod");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
add_extend_op<hip_rnn_var_sl_shift_output, op::rnn_var_sl_shift_output>(
"rnn_var_sl_shift_output");
add_extend_op<hip_rnn_var_sl_shift_sequence, op::rnn_var_sl_shift_sequence>(
"rnn_var_sl_shift_sequence");
add_extend_op<hip_rnn_var_sl_last_output, op::rnn_var_sl_last_output>(
"rnn_var_sl_last_output");
add_generic_op("acos");
add_generic_op("acosh");
add_generic_op("add");
add_generic_op("asin");
add_generic_op("asinh");
add_generic_op("atan");
add_generic_op("atanh");
add_generic_op("ceil");
add_generic_op("contiguous");
add_generic_op("cos");
add_generic_op("cosh");
add_generic_op("div");
add_generic_op("erf");
add_generic_op("exp");
add_generic_op("floor");
add_generic_op("log");
add_generic_op("max");
add_generic_op("min");
add_generic_op("mul");
add_generic_op("pow");
add_generic_op("prelu");
add_generic_op("recip");
add_generic_op("relu");
add_generic_op("round");
add_generic_op("rsqrt");
add_generic_op("sigmoid");
add_generic_op("sign");
add_generic_op("sin");
add_generic_op("sinh");
add_generic_op("sqdiff");
add_generic_op("sqrt");
add_generic_op("sub");
add_generic_op("tan");
add_generic_op("tanh");
add_extend_op("argmax");
add_extend_op("argmin");
add_extend_op("clip");
add_extend_op("concat");
add_extend_op("convert");
add_extend_op("gather");
add_extend_op("logsoftmax");
add_extend_op("pad");
add_extend_op("reduce_max");
add_extend_op("reduce_mean");
add_extend_op("reduce_min");
add_extend_op("reduce_prod");
add_extend_op("reduce_sum");
add_extend_op("rnn_var_sl_last_output");
add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("softmax");
add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot");
add_lrn_op();
......@@ -379,28 +336,30 @@ struct miopen_apply
});
}
template <class T>
void add_generic_op(std::string name)
void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); }
void add_generic_op(const std::string& op_name, const std::string& gpu_name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
apply_map.emplace(op_name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return prog->replace_instruction(ins, T{}, refs);
return prog->replace_instruction(ins, make_op(gpu_name), refs);
});
}
template <class T, class Op>
void add_extend_op(std::string name)
void add_extend_op(const std::string& name) { add_extend_op(name, "gpu::" + name); }
void add_extend_op(const std::string& op_name, const std::string& gpu_name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
auto&& op = any_cast<Op>(ins->get_operator());
apply_map.emplace(op_name, [=](instruction_ref ins) {
auto&& op = ins->get_operator();
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return prog->replace_instruction(ins, T{op}, refs);
return prog->replace_instruction(ins, make_op(gpu_name, op.to_value()), refs);
});
}
......@@ -472,7 +431,8 @@ struct miopen_apply
std::vector<float> zeros(s.elements(), 0.0f);
auto l0 = prog->add_literal(literal(s, zeros));
auto output = insert_allocation(ins, s);
return prog->replace_instruction(ins, hip_sub{}, l0, ins->inputs().front(), output);
return prog->replace_instruction(
ins, make_op("gpu::sub"), l0, ins->inputs().front(), output);
});
}
};
......
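
On the GPU side the same mechanism replaces most of the per-op template glue: operators registered through add_generic_op/add_extend_op are now looked up as "gpu::" + name, and for stateful ops the reference operator's reflected fields are forwarded through op.to_value(). A rough sketch of that forwarding outside the miopen_apply class, assuming the GPU target library is linked so its operators are registered (the reduce_sum axes are only illustrative):

#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>

// Sketch: build the gpu counterpart of a reference operator by name,
// copying its reflected fields, as add_extend_op now does in lowering.
migraphx::operation to_gpu_op(const migraphx::operation& op)
{
    return migraphx::make_op("gpu::" + op.name(), op.to_value());
}

void lowering_example()
{
    auto ref_op = migraphx::make_op("reduce_sum", {{"axes", {0}}});
    auto gpu_op = to_gpu_op(ref_op); // yields gpu::reduce_sum with the same axes
    (void)gpu_op;
}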
......@@ -209,6 +209,14 @@ std::vector<value>& get_array_impl(const std::shared_ptr<value_base_impl>& x)
return *a;
}
std::vector<value>& get_array_throw(const std::shared_ptr<value_base_impl>& x)
{
auto* a = if_array_impl(x);
if(a == nullptr)
MIGRAPHX_THROW("Expected an array or object");
return *a;
}
value* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key)
{
auto* a = if_array_impl(x);
......@@ -302,15 +310,29 @@ const value& value::at(const std::string& pkey) const
{
auto* r = find(pkey);
if(r == nullptr)
MIGRAPHX_THROW("Not an object");
MIGRAPHX_THROW("Not an object for field: " + pkey);
if(r == end())
MIGRAPHX_THROW("Key not found");
MIGRAPHX_THROW("Key not found: " + pkey);
return *r;
}
value& value::operator[](std::size_t i) { return *(begin() + i); }
const value& value::operator[](std::size_t i) const { return *(begin() + i); }
value& value::operator[](const std::string& pkey) { return *emplace(pkey, nullptr).first; }
void value::clear() { get_array_throw(x).clear(); }
void value::resize(std::size_t n)
{
if(not is_array())
MIGRAPHX_THROW("Expected an array.");
get_array_impl(x).resize(n);
}
void value::resize(std::size_t n, const value& v)
{
if(not is_array())
MIGRAPHX_THROW("Expected an array.");
get_array_impl(x).resize(n, v);
}
std::pair<value*, bool> value::insert(const value& v)
{
if(v.key.empty())
......
......@@ -44,6 +44,9 @@ function(add_test_command NAME EXE)
# --args $<TARGET_FILE:${EXE}> ${ARGN})
set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME})
file(MAKE_DIRECTORY ${TEST_DIR})
if (NOT EXISTS ${TEST_DIR})
message(FATAL_ERROR "Failed to create test directory: ${TEST_DIR}")
endif()
file(GENERATE OUTPUT "${TEST_DIR}/run.cmake"
CONTENT "
# Remove previous core dump
......
......@@ -2014,7 +2014,7 @@ TEST_CASE(transpose_gather_test)
auto prog = optimize_onnx("transpose_gather_test.onnx");
EXPECT(p == prog);
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(undefined_test)
......
#include <migraphx/register_op.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <sstream>
#include <string>
#include "test.hpp"
......@@ -13,6 +15,35 @@ TEST_CASE(load_op)
}
}
TEST_CASE(make_op)
{
for(const auto& name : migraphx::get_operators())
{
auto op = migraphx::load_op(name);
CHECK(op == migraphx::make_op(name));
}
}
TEST_CASE(make_op_from_value1)
{
migraphx::operation x = migraphx::make_op(
"convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}});
migraphx::operation y = migraphx::op::convolution{{1, 1}, {2, 2}, {2, 2}};
EXPECT(x == y);
}
TEST_CASE(make_op_from_value2)
{
migraphx::operation x = migraphx::make_op("convolution", {{"padding", {1, 1}}});
migraphx::operation y = migraphx::op::convolution{{1, 1}};
EXPECT(x == y);
}
TEST_CASE(make_op_invalid_key)
{
EXPECT(test::throws([] { migraphx::make_op("convolution", {{"paddings", {1, 1}}}); }));
}
TEST_CASE(ops)
{
auto names = migraphx::get_operators();
......
......@@ -69,4 +69,35 @@ TEST_CASE(serialize_reflectable_type)
EXPECT(v2 != v3);
}
TEST_CASE(serialize_empty_array)
{
std::vector<std::size_t> ints = {};
migraphx::value v = migraphx::to_value(ints);
EXPECT(v.is_array());
EXPECT(v.empty());
v.push_back(1);
EXPECT(v.size() == 1);
EXPECT(v.front().to<int>() == 1);
}
struct empty_struct
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return migraphx::pack();
}
};
TEST_CASE(serialize_empty_struct)
{
empty_struct es{};
migraphx::value v = migraphx::to_value(es);
EXPECT(v.is_object());
EXPECT(v.empty());
v["a"] = 1;
EXPECT(v.size() == 1);
EXPECT(v.at("a").to<int>() == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -563,4 +563,132 @@ TEST_CASE(print)
EXPECT(ss.str() == "{1, {one: 1, two: 2}, {1, 2}, null}");
}
TEST_CASE(value_clear)
{
migraphx::value values = {1, 2, 3};
EXPECT(values.is_array());
EXPECT(values.size() == 3);
values.clear();
EXPECT(values.empty());
values.push_back(3);
EXPECT(values.size() == 1);
EXPECT(values.at(0).to<int>() == 3);
}
TEST_CASE(value_clear_non_array)
{
migraphx::value values = 1.0;
EXPECT(test::throws([&] { values.clear(); }));
}
TEST_CASE(value_clear_object)
{
migraphx::value values = {{"a", 1}, {"b", 2}};
EXPECT(values.is_object());
EXPECT(values.size() == 2);
values.clear();
EXPECT(values.empty());
values["c"] = 3;
EXPECT(values.size() == 1);
EXPECT(values.at("c").to<int>() == 3);
}
TEST_CASE(value_clear_empty_array)
{
migraphx::value values = migraphx::value::array{};
EXPECT(values.empty());
values.clear();
EXPECT(values.empty());
}
TEST_CASE(value_clear_empty_object)
{
migraphx::value values = migraphx::value::object{};
EXPECT(values.empty());
values.clear();
EXPECT(values.empty());
}
TEST_CASE(value_resize)
{
migraphx::value values = {1, 2, 3};
EXPECT(values.is_array());
EXPECT(values.size() == 3);
values.resize(5);
EXPECT(values.size() == 5);
EXPECT(values.at(3).is_null());
EXPECT(values.at(4).is_null());
}
TEST_CASE(value_resize_with_value)
{
migraphx::value values = {1, 2, 3};
EXPECT(values.is_array());
EXPECT(values.size() == 3);
values.resize(5, 7);
EXPECT(values.size() == 5);
EXPECT(values.at(3).to<int>() == 7);
EXPECT(values.at(4).to<int>() == 7);
}
TEST_CASE(value_resize_empty_array)
{
migraphx::value values = migraphx::value::array{};
EXPECT(values.is_array());
EXPECT(values.empty());
values.resize(3);
EXPECT(values.size() == 3);
EXPECT(values.at(0).is_null());
EXPECT(values.at(1).is_null());
EXPECT(values.at(2).is_null());
}
TEST_CASE(value_resize_object)
{
migraphx::value values = migraphx::value::object{};
EXPECT(values.is_object());
EXPECT(test::throws([&] { values.resize(4); }));
}
TEST_CASE(value_resize_n_object)
{
migraphx::value values = migraphx::value::object{};
EXPECT(values.is_object());
EXPECT(test::throws([&] { values.resize(4, ""); }));
}
TEST_CASE(value_assign_construct_from_vector)
{
std::vector<int> v = {1, 2, 3};
migraphx::value values = v;
EXPECT(values.to_vector<int>() == v);
}
TEST_CASE(value_construct_from_vector)
{
std::vector<int> v = {1, 2, 3};
migraphx::value values(v);
EXPECT(values.to_vector<int>() == v);
}
TEST_CASE(value_assign_from_vector)
{
std::vector<int> v = {1, 2, 3};
migraphx::value values{};
values = v;
EXPECT(values.to_vector<int>() == v);
}
TEST_CASE(value_init_from_vector)
{
std::vector<int> v = {1, 2, 3};
migraphx::value values = {{"a", v}};
EXPECT(values.at("a").to_vector<int>() == v);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }