Unverified Commit 002eb4e2 authored by Paul Fultz II, committed by GitHub

Add C++ ability to construct operators by name (#616)



* Add make_op function

* Formatting

* Add more values

* Formatting

* Remove templates parse_conv functions

* Formatting

* Remove mat_mul template

* Formatting

* Reduce header includes

* Fix compiling for gpu

* Formatting

* Use make_op in lowering

* Formatting

* Sort lines

* Formatting

* Add more tests

* Formatting

* Fix tidy error

* Formatting

* Add const refs

* Add explicit this

* Add more const refs

* Sort the program

* Remove commented out code

* Formatting

* Infer gpu prefix

* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 56b3bf58
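As a quick illustration of the API this commit introduces (a minimal sketch, not part of the diff; the operator names and attribute keys shown are the reflected names used elsewhere in this change):

#include <migraphx/make_op.hpp>

void example()
{
    // Construct an operator purely from its registered name.
    auto relu = migraphx::make_op("relu");

    // Construct an operator and override selected attributes; keys that are
    // not listed keep the operator's default values.
    auto conv = migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}});
}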
@@ -20,6 +20,7 @@ add_library(migraphx
env.cpp
generate.cpp
instruction.cpp
+ make_op.cpp
msgpack.cpp
program.cpp
quantization.cpp
...
#ifndef MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -23,7 +23,7 @@ struct quant_dot
template <class Self, class F>
static auto reflect(Self& self, F f)
{
- return pack(f(as_number(self.alpha), "alpha"), f(self.beta, "beta"));
+ return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "quant_dot"; }
...
@@ -47,7 +47,7 @@ value to_value_impl(rank<1>, const std::pair<T, U>& x)
template <class T>
auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
{
- value result;
+ value result = value::array{};
for(auto&& y : x)
{
auto e = to_value(y);
@@ -59,7 +59,7 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
value to_value_impl(rank<3>, const T& x)
{
- value result;
+ value result = value::object{};
reflect_each(x, [&](auto&& y, std::string name) { result.emplace(name, to_value(y)); });
return result;
}
...
@@ -73,6 +73,13 @@ struct value_converter<std::pair<T, U>>
};
namespace detail {
+ template <class To, class Key, class From>
+ auto try_convert_value_impl(rank<2>, const std::pair<Key, From>& x)
+     -> decltype(value_converter<To>::apply(x.second))
+ {
+     return value_converter<To>::apply(x.second);
+ }
template <class To, class From>
auto try_convert_value_impl(rank<1>, const From& x) -> decltype(value_converter<To>::apply(x))
{
@@ -89,7 +96,7 @@ To try_convert_value_impl(rank<0>, const From& x)
template <class To, class From>
To try_convert_value(const From& x)
{
- return detail::try_convert_value_impl<To>(rank<1>{}, x);
+ return detail::try_convert_value_impl<To>(rank<2>{}, x);
}
struct value
@@ -159,6 +166,26 @@ struct value
using is_pickable =
std::integral_constant<bool, (std::is_arithmetic<T>{} and not std::is_pointer<T>{})>;
+ template <class T>
+ using range_value = std::decay_t<decltype(std::declval<T>().end(), *std::declval<T>().begin())>;
+ template <class T>
+ using is_generic_range =
+     std::integral_constant<bool,
+                            (std::is_convertible<range_value<T>, value>{} and
+                             not std::is_convertible<T, array>{} and
+                             not std::is_convertible<T, object>{})>;
+ template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
+ value(const T& r) : value(from_values(r))
+ {
+ }
+ template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
+ value(const std::string& pkey, const T& r) : value(pkey, from_values(r))
+ {
+ }
template <class T, MIGRAPHX_REQUIRES(is_pickable<T>{})>
value(T i) : value(pick<T>{i})
{
@@ -176,6 +203,11 @@ struct value
{
return *this = pick<T>{rhs}; // NOLINT
}
+ template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
+ value& operator=(T rhs)
+ {
+     return *this = from_values(rhs); // NOLINT
+ }
value& operator=(std::nullptr_t);
@@ -214,6 +246,10 @@ struct value
const value& operator[](std::size_t i) const;
value& operator[](const std::string& pkey);
+ void clear();
+ void resize(std::size_t n);
+ void resize(std::size_t n, const value& v);
std::pair<value*, bool> insert(const value& v);
value* insert(const value* pos, const value& v);
@@ -294,6 +330,14 @@ struct value
void debug_print(bool show_type = false) const;
private:
+ template <class T>
+ std::vector<value> from_values(const T& r)
+ {
+     std::vector<value> v;
+     std::transform(
+         r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
+     return v;
+ }
type_t get_type() const;
std::shared_ptr<value_base_impl> x;
std::string key;
...
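The generic-range constructor and assignment added above let a value be built directly from any range whose elements convert to value, which is what allows attribute lists to be edited in place when working with make_op. A minimal sketch of the intent (the key names and element type here are only illustrative):

#include <migraphx/value.hpp>
#include <vector>

void example()
{
    // Start from an empty object, as to_value() does for reflected operators.
    migraphx::value v = migraphx::value::object{};
    std::vector<std::size_t> pads = {1, 1, 1, 1};
    // Assigning a plain range now goes through the new generic-range operator=.
    v["padding"] = pads;
    // The keyed constructor accepts a range as well.
    migraphx::value keyed("stride", std::vector<std::size_t>{2, 2});
}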
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
{
auto op = load_op(name);
// Merge values
value w = op.to_value();
for(auto&& x : v)
{
w.at(x.get_key()) = x.without_key();
}
op.from_value(w);
return op;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
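The two-argument overload above starts from the operator's default attributes (op.to_value()) and overwrites only the keys the caller supplies, so a call site only needs to name the attributes it wants to change. A small sketch of the effect, assuming the usual "convolution" defaults:

#include <migraphx/make_op.hpp>

void example()
{
    // Only "stride" is overridden; "padding", "group", etc. keep their defaults.
    auto op = migraphx::make_op("convolution", {{"stride", {2, 2}}});
    auto v  = op.to_value();
    // v["stride"] now holds {2, 2}; the remaining keys are whatever
    // load_op("convolution") produced as defaults.
}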
@@ -11,7 +11,7 @@
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
- #include <migraphx/operators.hpp>
+ #include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
@@ -20,6 +20,31 @@
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
+ #include <migraphx/op/as_shape.hpp>
+ #include <migraphx/op/batch_norm_inference.hpp>
+ #include <migraphx/op/broadcast.hpp>
+ #include <migraphx/op/concat.hpp>
+ #include <migraphx/op/convert.hpp>
+ #include <migraphx/op/gather.hpp>
+ #include <migraphx/op/gru.hpp>
+ #include <migraphx/op/lrn.hpp>
+ #include <migraphx/op/lstm.hpp>
+ #include <migraphx/op/multibroadcast.hpp>
+ #include <migraphx/op/pad.hpp>
+ #include <migraphx/op/reshape.hpp>
+ #include <migraphx/op/rnn.hpp>
+ #include <migraphx/op/rnn_last_cell_output.hpp>
+ #include <migraphx/op/rnn_last_hs_output.hpp>
+ #include <migraphx/op/rnn_variable_seq_lens.hpp>
+ #include <migraphx/op/rnn_var_sl_last_output.hpp>
+ #include <migraphx/op/scalar.hpp>
+ #include <migraphx/op/slice.hpp>
+ #include <migraphx/op/squeeze.hpp>
+ #include <migraphx/op/transpose.hpp>
+ #include <migraphx/op/undefined.hpp>
+ #include <migraphx/op/unknown.hpp>
+ #include <migraphx/op/unsqueeze.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -50,63 +75,67 @@ struct onnx_parser
onnx_parser()
{
// sort onnx operator alphabetically through name
add_generic_op("Abs", op::abs{}); add_generic_op("Abs", "abs");
add_generic_op("Acos", op::acos{}); add_generic_op("Acos", "acos");
add_generic_op("Acosh", op::acosh{}); add_generic_op("Acosh", "acosh");
add_generic_op("Asin", op::asin{}); add_generic_op("Asin", "asin");
add_generic_op("Asinh", op::asinh{}); add_generic_op("Asinh", "asinh");
add_generic_op("Atan", op::atan{}); add_generic_op("Atan", "atan");
add_generic_op("Atanh", op::atanh{}); add_generic_op("Atanh", "atanh");
add_generic_op("Ceil", op::ceil{}); add_generic_op("Ceil", "ceil");
add_generic_op("Cos", op::cos{}); add_generic_op("Concat", "concat");
add_generic_op("Cosh", op::cosh{}); add_generic_op("Cos", "cos");
add_generic_op("Erf", op::erf{}); add_generic_op("Cosh", "cosh");
add_generic_op("Exp", op::exp{}); add_generic_op("Dropout", "identity");
add_generic_op("Dropout", op::identity{}); add_generic_op("Erf", "erf");
add_generic_op("Floor", op::floor{}); add_generic_op("Exp", "exp");
add_generic_op("Identity", op::identity{}); add_generic_op("Flatten", "flatten");
add_generic_op("Log", op::log{}); add_generic_op("Floor", "floor");
add_generic_op("Neg", op::neg{}); add_generic_op("Gather", "gather", true);
add_generic_op("Reciprocal", op::recip{}); add_generic_op("Identity", "identity");
add_generic_op("Relu", op::relu{}); add_generic_op("Log", "log");
add_generic_op("Round", op::round{}); add_generic_op("LogSoftmax", "logsoftmax");
add_generic_op("Sigmoid", op::sigmoid{}); add_generic_op("Neg", "neg");
add_generic_op("Sign", op::sign{}); add_generic_op("Reciprocal", "recip");
add_generic_op("Sin", op::sin{}); add_generic_op("Relu", "relu");
add_generic_op("Sinh", op::sinh{}); add_generic_op("Round", "round");
add_generic_op("Sqrt", op::sqrt{}); add_generic_op("Sigmoid", "sigmoid");
add_generic_op("Tan", op::tan{}); add_generic_op("Sign", "sign");
add_generic_op("Tanh", op::tanh{}); add_generic_op("Sin", "sin");
add_generic_op("Sinh", "sinh");
add_binary_op("Add", op::add{}); add_generic_op("Softmax", "softmax");
add_binary_op("Div", op::div{}); add_generic_op("Sqrt", "sqrt");
add_binary_op("Mul", op::mul{}); add_generic_op("Squeeze", "squeeze", true);
add_binary_op("Pow", op::pow{}); add_generic_op("Tan", "tan");
add_binary_op("PRelu", op::prelu{}); add_generic_op("Tanh", "tanh");
add_binary_op("Sub", op::sub{}); add_generic_op("Unsqueeze", "unsqueeze", true);
add_variadic_op("Sum", op::add{}); add_binary_op("Add", "add");
add_variadic_op("Max", op::max{}); add_binary_op("Div", "div");
add_variadic_op("Min", op::min{}); add_binary_op("Mul", "mul");
add_binary_op("Pow", "pow");
add_binary_op("PRelu", "prelu");
add_binary_op("Sub", "sub");
add_variadic_op("Sum", "add");
add_variadic_op("Max", "max");
add_variadic_op("Min", "min");
add_mem_op("ATen", &onnx_parser::parse_aten); add_mem_op("ATen", &onnx_parser::parse_aten);
add_mem_op("AveragePool", &onnx_parser::parse_pooling); add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>); add_mem_op("ArgMax", "argmax", &onnx_parser::parse_arg_op);
add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>); add_mem_op("ArgMin", "argmin", &onnx_parser::parse_arg_op);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Cast", &onnx_parser::parse_cast); add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip); add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Constant", &onnx_parser::parse_constant); add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill); add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape); add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>); add_mem_op("Conv", "convolution", &onnx_parser::parse_conv);
add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>); add_mem_op("ConvInteger", "quant_convolution", &onnx_parser::parse_conv);
add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose); add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
add_mem_op("Elu", &onnx_parser::parse_elu); add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand); add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("GatherElements", &onnx_parser::parse_gather_elements); add_mem_op("GatherElements", &onnx_parser::parse_gather_elements);
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling); add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
...@@ -115,11 +144,10 @@ struct onnx_parser ...@@ -115,11 +144,10 @@ struct onnx_parser
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm); add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("LSTM", &onnx_parser::parse_lstm); add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>); add_mem_op("MatMul", "dot", &onnx_parser::parse_matmul);
add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>); add_mem_op("MatMulInteger", "quant_dot", &onnx_parser::parse_matmul);
add_mem_op("MaxPool", &onnx_parser::parse_pooling); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("NonZero", &onnx_parser::parse_nonzero); add_mem_op("NonZero", &onnx_parser::parse_nonzero);
add_mem_op("OneHot", &onnx_parser::parse_onehot); add_mem_op("OneHot", &onnx_parser::parse_onehot);
...@@ -129,22 +157,19 @@ struct onnx_parser ...@@ -129,22 +157,19 @@ struct onnx_parser
add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2); add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum); add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp); add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>); add_mem_op("ReduceMax", "reduce_max", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>); add_mem_op("ReduceMean", "reduce_mean", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>); add_mem_op("ReduceMin", "reduce_min", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>); add_mem_op("ReduceProd", "reduce_prod", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>); add_mem_op("ReduceSum", "reduce_sum", &onnx_parser::parse_reduce_oper);
add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square); add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
add_mem_op("Reshape", &onnx_parser::parse_reshape); add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("RNN", &onnx_parser::parse_rnn); add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("Shape", &onnx_parser::parse_shape); add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("Split", &onnx_parser::parse_split); add_mem_op("Split", &onnx_parser::parse_split);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Tile", &onnx_parser::parse_tile); add_mem_op("Tile", &onnx_parser::parse_tile);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
// init the activation function map // init the activation function map
init_actv_func(); init_actv_func();
...@@ -153,11 +178,39 @@ struct onnx_parser ...@@ -153,11 +178,39 @@ struct onnx_parser
void init_actv_func()
{
// Support name format of all lower case or the first letter capital
- map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
- map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
- map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
- map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
- map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
+ map_actv_funcs.insert(std::make_pair("tanh", make_op("tanh")));
+ map_actv_funcs.insert(std::make_pair("relu", make_op("relu")));
+ map_actv_funcs.insert(std::make_pair("sigmoid", make_op("sigmoid")));
+ map_actv_funcs.insert(std::make_pair("leakyrelu", make_op("leaky_relu")));
+ map_actv_funcs.insert(std::make_pair("elu", make_op("elu")));
}
+ static operation load(const std::string& name, const node_info& info)
+ {
+     auto op = make_op(name);
+     auto v = op.to_value();
+     for(auto&& x : v)
+     {
+         if(info.attributes.count(x.get_key()) == 0)
+             continue;
+         literal s = parse_value(info.attributes.at(x.get_key()));
+         if(x.is_array())
+         {
+             std::vector<value> values;
+             s.visit([&](auto y) {
+                 std::transform(y.begin(), y.end(), std::back_inserter(values), [](auto z) {
+                     return value(z);
+                 });
+             });
+             x = values;
+         }
+         else
+         {
+             s.visit([&](auto y) { x = y.front(); });
+         }
+     }
+     op.from_value(v);
+     return op;
+ }
template <class F>
@@ -176,17 +229,24 @@ struct onnx_parser
}
template <class F>
- void add_mem_op(std::string name, F f)
+ void add_mem_op(const std::string& name, F f)
{
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
- template <class T>
- void add_binary_op(std::string name, T x)
- {
- add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
+ template <class F>
+ void add_mem_op(const std::string& onnx_name, const std::string& op_name, F f)
+ {
+     add_op(onnx_name, [=](auto&&... xs) {
+         return std::mem_fn(f)(*this, onnx_name, op_name, std::forward<decltype(xs)>(xs)...);
+     });
+ }
+ void add_binary_op(const std::string& onnx_name, const std::string& op_name)
+ {
+     add_op(onnx_name, [this, op_name](node_info info, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
@@ -197,13 +257,13 @@ struct onnx_parser
uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
args[1]);
- return prog.add_instruction(x, args[0], l);
+ return prog.add_instruction(make_op(op_name), args[0], l);
}
- return prog.add_instruction(x, args);
+ return prog.add_instruction(make_op(op_name), args);
}
else
{
- return add_broadcastable_binary_op(args[0], args[1], x);
+ return add_broadcastable_binary_op(args[0], args[1], op_name);
}
});
}
@@ -254,11 +314,11 @@ struct onnx_parser
return ins;
}
- return prog.add_instruction(op::contiguous{}, ins);
+ return prog.add_instruction(make_op("contiguous"), ins);
}
- template <class T>
- instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
+ instruction_ref
+ add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, const std::string& name)
{
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
@@ -275,31 +335,40 @@ struct onnx_parser
if(arg1->get_shape().lens() != out_lens)
l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
- return prog.add_instruction(x, l0, l1);
+ return prog.add_instruction(make_op(name), l0, l1);
}
else
{
- return prog.add_instruction(x, {arg0, arg1});
+ return prog.add_instruction(make_op(name), {arg0, arg1});
}
}
- template <class T>
- void add_generic_op(std::string name, T x)
- {
- add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
- return prog.add_instruction(x, args);
+ void add_generic_op(const std::string& onnx_name,
+                     const std::string& op_name,
+                     bool contiguous = false)
+ {
+     add_op(
+         onnx_name,
+         [this, op_name, contiguous](const node_info& info, std::vector<instruction_ref> args) {
+             auto op = load(op_name, info);
+             if(contiguous)
+             {
+                 std::transform(args.begin(), args.end(), args.begin(), [&](auto arg) {
+                     return this->make_contiguous(arg);
+                 });
+             }
+             return prog.add_instruction(op, args);
});
}
- template <class T>
- void add_variadic_op(std::string name, T x)
+ void add_variadic_op(const std::string& onnx_name, const std::string& op_name)
{
- add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
+ add_op(onnx_name, [this, op_name](const node_info&, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()),
args.end(),
args.front(),
- [this, x](instruction_ref a, instruction_ref b) {
- return add_broadcastable_binary_op(a, b, x);
+ [this, op_name](instruction_ref a, instruction_ref b) {
+ return add_broadcastable_binary_op(a, b, op_name);
});
});
}
@@ -318,7 +387,7 @@ struct onnx_parser
{
auto bias_bcast =
prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
- return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
+ return prog.add_instruction(make_op("add"), curr_ins, bias_bcast);
}
return curr_ins;
}
@@ -338,10 +407,9 @@ struct onnx_parser
return false;
}
- template <class Op>
void check_asym_padding(instruction_ref& ins,
const std::vector<int64_t>& padding,
- Op& op,
+ value& v,
int count_include_pad = 0,
float pad_val = 0)
{
@@ -360,7 +428,7 @@ struct onnx_parser
}
else
{
- op.padding = std::vector<size_t>(left_pad_it, right_pad_it);
+ v["padding"] = std::vector<size_t>(left_pad_it, right_pad_it);
}
}
@@ -404,29 +472,17 @@ struct onnx_parser
max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);
if(min_used and max_used)
- return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
+ return prog.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
if(min_used)
- return prog.add_instruction(op::max{}, args[0], min_arg);
- return prog.add_instruction(op::identity{}, args[0]);
- }
- template <class Op>
- instruction_ref
- parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
- {
- int64_t axis = 1;
- if(contains(info.attributes, "axis"))
- {
- axis = parse_value(info.attributes.at("axis")).at<int>();
- }
- return prog.add_instruction(Op{axis}, std::move(args));
+ return prog.add_instruction(make_op("max"), args[0], min_arg);
+ return prog.add_instruction(make_op("identity"), args[0]);
}
- template <class Op>
- instruction_ref
- parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
+ instruction_ref parse_arg_op(const std::string&,
+                              const std::string& op_name,
+                              node_info info,
+                              std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(info.attributes, "axis"))
@@ -442,12 +498,12 @@ struct onnx_parser
if(keep_dims == 0)
{
- auto ins = prog.add_instruction(Op{axis}, std::move(args));
+ auto ins = prog.add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
- return prog.add_instruction(Op{axis}, std::move(args));
+ return prog.add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
}
}
@@ -536,29 +592,27 @@ struct onnx_parser
}
}
- template <class Op>
- void recalc_conv_attributes(Op& op, size_t kdims)
+ void recalc_conv_attributes(value& v, size_t kdims)
{
- if(op.padding.size() != kdims)
+ if(v["padding"].size() != kdims)
{
- op.padding.resize(kdims);
- std::fill_n(op.padding.begin(), kdims, 0);
+ v["padding"].resize(kdims);
+ std::fill_n(v["padding"].begin(), kdims, 0);
}
- if(op.stride.size() != kdims)
+ if(v["stride"].size() != kdims)
{
- op.stride.resize(kdims);
- std::fill_n(op.stride.begin(), kdims, 1);
+ v["stride"].resize(kdims);
+ std::fill_n(v["stride"].begin(), kdims, 1);
}
- if(op.dilation.size() != kdims)
+ if(v["dilation"].size() != kdims)
{
- op.dilation.resize(kdims);
- std::fill_n(op.dilation.begin(), kdims, 1);
+ v["dilation"].resize(kdims);
+ std::fill_n(v["dilation"].begin(), kdims, 1);
}
}
- template <class Op>
static void cal_auto_padding_size(node_info info,
- Op& op,
+ value& v,
const std::vector<std::size_t>& k_lens,
const std::vector<std::size_t>& dilation,
const std::vector<std::size_t>& in_lens,
@@ -575,7 +629,7 @@ struct onnx_parser
auto auto_pad = info.attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos)
{
- op.padding_mode = op::padding_mode_t::same;
+ v["padding_mode"] = to_value(op::padding_mode_t::same);
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
paddings.resize(2 * kdims);
@@ -584,7 +638,7 @@ struct onnx_parser
calculate_padding(i,
paddings,
in_lens[i + 2],
- op.stride[i],
+ v["stride"][i].to<int64_t>(),
dilation[i],
k_lens[i],
is_same_upper);
@@ -606,11 +660,13 @@ struct onnx_parser
}
}
- template <class Op>
- instruction_ref
- parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
+ instruction_ref parse_conv(const std::string&,
+                            const std::string& op_name,
+                            node_info info,
+                            std::vector<instruction_ref> args)
{
- Op op;
+ auto op = make_op(op_name);
+ auto values = op.to_value();
auto l0 = args[0];
auto weights = args[1];
auto in_lens = l0->get_shape().lens();
@@ -622,21 +678,22 @@ struct onnx_parser
if(contains(info.attributes, "strides"))
{
- op.stride.clear();
- copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
- check_attr_sizes(kdims, op.stride.size(), "PARSE_CONV: inconsistent strides");
+ values["stride"].clear();
+ copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
+ check_attr_sizes(kdims, values["stride"].size(), "PARSE_CONV: inconsistent strides");
}
if(contains(info.attributes, "dilations"))
{
- op.dilation.clear();
- copy(info.attributes["dilations"].ints(), std::back_inserter(op.dilation));
- check_attr_sizes(kdims, op.dilation.size(), "PARSE_CONV: inconsistent dilations");
+ values["dilation"].clear();
+ copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
+ check_attr_sizes(
+     kdims, values["dilation"].size(), "PARSE_CONV: inconsistent dilations");
}
std::vector<int64_t> padding;
if(contains(info.attributes, "pads"))
{
- op.padding.clear();
+ values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
}
@@ -646,17 +703,23 @@ struct onnx_parser
auto weight_lens = weights->get_shape().lens();
std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
- cal_auto_padding_size(info, op, k_lens, op.dilation, in_lens, padding);
+ cal_auto_padding_size(info,
+                       values,
+                       k_lens,
+                       values["dilation"].to_vector<std::size_t>(),
+                       in_lens,
+                       padding);
}
- check_asym_padding(l0, padding, op);
+ check_asym_padding(l0, padding, values);
if(contains(info.attributes, "group"))
{
- op.group = parse_value(info.attributes.at("group")).at<int>();
+ values["group"] = parse_value(info.attributes.at("group")).at<int>();
}
- recalc_conv_attributes(op, kdims);
+ recalc_conv_attributes(values, kdims);
+ op.from_value(values);
auto l1 = prog.add_instruction(op, l0, args[1]);
return add_bias(args, l1, 1);
}
@@ -664,7 +727,9 @@ struct onnx_parser
instruction_ref
parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
{
- op::deconvolution op;
+ operation op = make_op("deconvolution");
+ value values = op.to_value();
+ // op::deconvolution op;
auto l0 = args[0];
std::vector<std::int64_t> padding;
bool asym_padding = false;
@@ -685,25 +750,26 @@ struct onnx_parser
{
size_t pad_ndims = padding.size() / 2;
check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
- op.padding.clear();
+ values["padding"].clear();
std::transform(padding.begin(),
padding.begin() + pad_ndims,
- std::back_inserter(op.padding),
+ std::back_inserter(values["padding"]),
[](auto pad_val) { return pad_val; });
}
}
if(contains(info.attributes, "strides"))
{
- op.stride.clear();
- copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
- check_attr_sizes(kdims, op.stride.size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
+ values["stride"].clear();
+ copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
+ check_attr_sizes(
+     kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
}
if(contains(info.attributes, "dilations"))
{
- op.dilation.clear();
- copy(info.attributes["dilations"].ints(), std::back_inserter(op.dilation));
+ values["dilation"].clear();
+ copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
check_attr_sizes(
- kdims, op.dilation.size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
+ kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
}
if(contains(info.attributes, "auto_pad"))
{
@@ -716,17 +782,18 @@ struct onnx_parser
if(s.find("SAME") != std::string::npos)
{
- op.padding_mode = op::padding_mode_t::same;
+ values["padding_mode"] = to_value(op::padding_mode_t::same);
}
}
if(contains(info.attributes, "group"))
{
- op.group = parse_value(info.attributes.at("group")).at<int>();
+ values["group"] = parse_value(info.attributes.at("group")).at<int>();
}
- recalc_conv_attributes(op, kdims);
+ recalc_conv_attributes(values, kdims);
+ op.from_value(values);
auto l1 = prog.add_instruction(op, l0, args[1]);
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
@@ -799,13 +866,13 @@ struct onnx_parser
}
}
- static void tune_padding_size(const op::pooling& op,
+ static void tune_padding_size(const value& v,
std::vector<int64_t>& padding,
int count_include_pad,
std::vector<int64_t>& s_start)
{
// maxpooling or count_include_pad is 1, no change is required.
- if(op.mode == "max" or count_include_pad == 1)
+ if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1)
{
return;
}
@@ -821,14 +888,17 @@ struct onnx_parser
s_start.resize(n_dims);
for(std::size_t i = 0; i < n_dims; ++i)
{
- tune_padding_to_symmetric(padding[i], padding[i + n_dims], op.stride[i], s_start[i]);
+ tune_padding_to_symmetric(
+     padding[i], padding[i + n_dims], v.at("stride")[i].to<int64_t>(), s_start[i]);
}
}
instruction_ref
parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
{
- op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
+ std::string mode = ends_with(name, "MaxPool") ? "max" : "average";
+ operation op = make_op("pooling", {{"mode", mode}});
+ value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
@@ -836,7 +906,7 @@ struct onnx_parser
if(starts_with(name, "Global"))
{
- op.lengths = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
+ values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
}
// does not support ceil_mode
@@ -858,25 +928,26 @@ struct onnx_parser
if(contains(info.attributes, "strides"))
{
- op.stride.clear();
- copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
- check_attr_sizes(kdims, op.stride.size(), "PARSE_POOLING: inconsistent strides");
+ values["stride"].clear();
+ copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
+ check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
- op.lengths.clear();
- copy(info.attributes["kernel_shape"].ints(), std::back_inserter(op.lengths));
- check_attr_sizes(kdims, op.lengths.size(), "PARSE_POOLING: inconsistent lengths");
+ values["lengths"].clear();
+ copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
+ check_attr_sizes(
+     kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
std::vector<int64_t> paddings;
- float pad_val = ((op.mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
+ float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
- op.padding.clear();
+ values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
@@ -884,9 +955,14 @@ struct onnx_parser
if(contains(info.attributes, "auto_pad"))
{
- op.padding.clear();
+ values["padding"].clear();
// return paddings could be empty, then setting to 0 for no padding
- cal_auto_padding_size(info, op, op.lengths, {1, 1}, in_lens, paddings);
+ cal_auto_padding_size(info,
+                       values,
+                       values["lengths"].to_vector<std::size_t>(),
+                       {1, 1},
+                       in_lens,
+                       paddings);
}
if(paddings.size() != 2 * kdims)
@@ -895,23 +971,23 @@ struct onnx_parser
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
- if(op.padding.size() != kdims)
+ if(values["padding"].size() != kdims)
{
- op.padding.resize(kdims);
- std::fill_n(op.padding.begin(), kdims, 0);
+ values["padding"].resize(kdims);
+ std::fill_n(values["padding"].begin(), kdims, 0);
}
- if(op.stride.size() != kdims)
+ if(values["stride"].size() != kdims)
{
- op.stride.resize(kdims);
- std::fill_n(op.stride.begin(), kdims, 1);
+ values["stride"].resize(kdims);
+ std::fill_n(values["stride"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding(paddings.begin(), paddings.end());
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
- tune_padding_size(op, paddings, count_include_pad, slice_start);
+ tune_padding_size(values, paddings, count_include_pad, slice_start);
if(!slice_start.empty())
{
@@ -920,7 +996,7 @@ struct onnx_parser
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
- auto out_lens = op.compute_shape({padded_shape}).lens();
+ auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
@@ -931,16 +1007,17 @@ struct onnx_parser
[](auto i, auto j) { return i + j; });
}
- check_asym_padding(l0, paddings, op, count_include_pad, pad_val);
+ check_asym_padding(l0, paddings, values, count_include_pad, pad_val);
in_lens = l0->get_shape().lens();
for(size_t i = 0; i < kdims; i++)
{
- if(op.lengths[i] > in_lens[i + 2] + 2 * op.padding[i])
+ if(values["lengths"][i].to<int64_t>() >
+    in_lens[i + 2] + 2 * values["padding"][i].to<int64_t>())
{
MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
}
}
+ op.from_value(values);
auto l1 = prog.add_instruction(op, l0);
if(!slice_start.empty())
{
@@ -971,62 +1048,6 @@ struct onnx_parser
return prog.add_instruction(op, make_contiguous(args[0]));
}
- instruction_ref
- parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
- {
- int64_t axis = 1;
- if(contains(info.attributes, "axis"))
- {
- axis = parse_value(info.attributes.at("axis")).at<int>();
- }
- return prog.add_instruction(op::flatten{axis}, args[0]);
- }
- instruction_ref
- parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
- {
- op::squeeze op;
- literal s = parse_value(info.attributes.at("axes"));
- s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
- return prog.add_instruction(op, make_contiguous(args[0]));
- }
- instruction_ref
- parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
- {
- op::unsqueeze op;
- literal s = parse_value(info.attributes.at("axes"));
- s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
- return prog.add_instruction(op, make_contiguous(args[0]));
- }
- instruction_ref
- parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
- {
- // change to hande axis to be negative values
- if(!contains(info.attributes, "axis"))
- {
- MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
- }
- int axis = parse_value(info.attributes.at("axis")).at<int>();
- op::concat op{axis};
- return prog.add_instruction(op, std::move(args));
- }
- instruction_ref
- parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
- {
- int axis = 0;
- if(contains(info.attributes, "axis"))
- {
- axis = parse_value(info.attributes.at("axis")).at<int>();
- }
- op::gather op{axis};
- return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
- }
instruction_ref
parse_gather_elements(const std::string&, node_info info, std::vector<instruction_ref> args)
{
@@ -1077,9 +1098,9 @@ struct onnx_parser
auto l_dim_idx = prog.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = prog.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = prog.add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
- auto dim_diff = prog.add_instruction(op::sub{}, arg_ind, l_dim_idx);
- auto delta = prog.add_instruction(op::mul{}, dim_diff, l_stride);
- auto ind = prog.add_instruction(op::add{}, l_shape_idx, delta);
+ auto dim_diff = prog.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
+ auto delta = prog.add_instruction(make_op("mul"), dim_diff, l_stride);
+ auto ind = prog.add_instruction(make_op("add"), l_shape_idx, delta);
op::gather op{0};
return prog.add_instruction(op, arg_data, ind);
@@ -1214,16 +1235,18 @@ struct onnx_parser
{
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
}
- return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
+ return prog.add_instruction(
+     make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
}
}
- return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
+ return prog.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
}
- template <class Op>
- instruction_ref
- parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
+ instruction_ref parse_matmul(const std::string&,
+                              const std::string& op_name,
+                              const node_info&,
+                              std::vector<instruction_ref> args)
{
auto l0 = args[0];
auto l1 = args[1];
@@ -1270,7 +1293,8 @@ struct onnx_parser
}
}
- auto dot_res = prog.add_instruction(Op{1, 0}, bl0, bl1);
+ auto dot_res =
+     prog.add_instruction(make_op(op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
@@ -1326,22 +1350,22 @@ struct onnx_parser
auto bias = args[2];
auto dims = x->get_shape().lens();
- auto mean = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
+ auto mean = prog.add_instruction(make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast = prog.add_instruction(op::multibroadcast{dims}, mean);
- auto l0 = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
- auto variance = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
- auto l1 = prog.add_instruction(op::sub{}, x, mean_bcast);
+ auto l0 = prog.add_instruction(make_op("sqdiff"), x, mean_bcast);
+ auto variance = prog.add_instruction(make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
+ auto l1 = prog.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = prog.add_literal(epsilon);
auto epsilon_bcast = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = prog.add_instruction(op::multibroadcast{dims}, variance);
- auto l2 = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
- auto l3 = prog.add_instruction(op::rsqrt{}, l2);
- auto l4 = prog.add_instruction(op::mul{}, l1, l3);
+ auto l2 = prog.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
+ auto l3 = prog.add_instruction(make_op("rsqrt"), l2);
+ auto l4 = prog.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = prog.add_instruction(op::broadcast{1, dims}, scale);
;
auto bias_bcast = prog.add_instruction(op::broadcast{1, dims}, bias);
- auto l5 = prog.add_instruction(op::mul{}, l4, scale_bcast);
- return prog.add_instruction(op::add{}, l5, bias_bcast);
+ auto l5 = prog.add_instruction(make_op("mul"), l4, scale_bcast);
+ return prog.add_instruction(make_op("add"), l5, bias_bcast);
}
instruction_ref
@@ -1352,7 +1376,7 @@ struct onnx_parser
{
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
- op::leaky_relu op{alpha};
+ auto op = make_op("leaky_relu", {{"alpha", alpha}});
return prog.add_instruction(op, args.front());
}
@@ -1363,7 +1387,7 @@ struct onnx_parser
{
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
- op::elu op{alpha};
+ auto op = make_op("elu", {{"alpha", alpha}});
return prog.add_instruction(op, args.front());
}
@@ -1408,9 +1432,10 @@ struct onnx_parser
auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
- auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
+ auto img_scaled =
+     prog.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
- return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
+ return prog.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
instruction_ref
@@ -1447,7 +1472,7 @@ struct onnx_parser
// check if padding is actually being done (at least one value is nonzero)
if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
{
- return prog.add_instruction(migraphx::op::identity{}, args.front());
+ return prog.add_instruction(make_op("identity"), args.front());
}
if(contains(info.attributes, "mode"))
...@@ -2039,9 +2064,10 @@ struct onnx_parser ...@@ -2039,9 +2064,10 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output}; return {hidden_states, last_output, last_cell_output};
} }
template <class T> instruction_ref parse_reduce_oper(const std::string&,
instruction_ref const std::string& op_name,
parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args) node_info info,
std::vector<instruction_ref> args)
{ {
std::size_t n_dim = args.front()->get_shape().lens().size(); std::size_t n_dim = args.front()->get_shape().lens().size();
...@@ -2063,11 +2089,11 @@ struct onnx_parser ...@@ -2063,11 +2089,11 @@ struct onnx_parser
if(keep_dims == 1) if(keep_dims == 1)
{ {
return prog.add_instruction(T{axes}, std::move(args)); return prog.add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
} }
else else
{ {
auto ins = prog.add_instruction(T{axes}, std::move(args)); auto ins = prog.add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins); return prog.add_instruction(op::squeeze{axes}, ins);
} }
} }
...@@ -2075,38 +2101,38 @@ struct onnx_parser ...@@ -2075,38 +2101,38 @@ struct onnx_parser
instruction_ref instruction_ref
parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args) parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
auto abs_ins = prog.add_instruction(op::abs{}, args[0]); auto abs_ins = prog.add_instruction(make_op("abs"), args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins}); return parse_reduce_oper({}, "reduce_sum", std::move(info), {abs_ins});
} }
instruction_ref instruction_ref
parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args) parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]); auto square_ins = prog.add_instruction(make_op("mul"), args[0], args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins}); auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
return prog.add_instruction(op::sqrt{}, sum_ins); return prog.add_instruction(make_op("sqrt"), sum_ins);
} }
instruction_ref instruction_ref
parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args) parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args)); auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), std::move(args));
return prog.add_instruction(op::log{}, sum_ins); return prog.add_instruction(make_op("log"), sum_ins);
} }
instruction_ref instruction_ref
parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args) parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
auto exp_ins = prog.add_instruction(op::exp{}, args[0]); auto exp_ins = prog.add_instruction(make_op("exp"), args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins}); auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {exp_ins});
return prog.add_instruction(op::log{}, sum_ins); return prog.add_instruction(make_op("log"), sum_ins);
} }
instruction_ref instruction_ref
parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args) parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]); auto square_ins = prog.add_instruction(make_op("mul"), args[0], args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins}); return parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
} }
instruction_ref instruction_ref
...@@ -2214,11 +2240,11 @@ struct onnx_parser ...@@ -2214,11 +2240,11 @@ struct onnx_parser
auto off_val = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]); auto off_val = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
auto on_val = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]); auto on_val = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
auto diff = prog.add_instruction(op::sub{}, on_val, off_val); auto diff = prog.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val = prog.add_instruction(op::multibroadcast{lens}, off_val); auto unsq_off_val = prog.add_instruction(op::multibroadcast{lens}, off_val);
auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff); auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
auto l_mul = prog.add_instruction(op::mul{}, tr_out, unsq_diff_val); auto l_mul = prog.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return prog.add_instruction(op::add{}, l_mul, unsq_off_val); return prog.add_instruction(make_op("add"), l_mul, unsq_off_val);
} }
instruction_ref instruction_ref
...@@ -2302,9 +2328,15 @@ struct onnx_parser ...@@ -2302,9 +2328,15 @@ struct onnx_parser
auto l0 = prog.add_instruction(op::gather{}, args[0], args[1]); auto l0 = prog.add_instruction(op::gather{}, args[0], args[1]);
switch(reduce_mode) switch(reduce_mode)
{ {
case reduce_mode_t::sum: l0 = prog.add_instruction(op::reduce_sum{{0}}, l0); break; case reduce_mode_t::sum:
case reduce_mode_t::mean: l0 = prog.add_instruction(op::reduce_mean{{0}}, l0); break; l0 = prog.add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0);
case reduce_mode_t::max: l0 = prog.add_instruction(op::reduce_max{{0}}, l0); break; break;
case reduce_mode_t::mean:
l0 = prog.add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0);
break;
case reduce_mode_t::max:
l0 = prog.add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0);
break;
} }
return l0; return l0;
} }
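Note on the pattern used throughout this hunk: typed operator structs such as op::reduce_sum{{0}} are replaced by name-based construction with attributes supplied as a value. A minimal sketch of that equivalence, assuming migraphx/op/reduce_sum.hpp is where the typed form lives; the comparison relies on operation's equality operator, which the tests later in this commit also use:

#include <migraphx/make_op.hpp>
#include <migraphx/op/reduce_sum.hpp> // assumed header for the typed op
#include <cassert>

int main()
{
    // Name-based construction with attributes passed as a value, as in the parser above.
    migraphx::operation by_name = migraphx::make_op("reduce_sum", {{"axes", {0}}});
    // Typed construction, shown only for comparison with the old code path.
    migraphx::operation typed = migraphx::op::reduce_sum{{0}};
    assert(by_name == typed);
    return 0;
}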
......
#include <rocblas.h>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/op/abs.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/op/deconvolution.hpp>
#include <migraphx/gpu/device/contiguous.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/gpu/device/add.hpp> #include <migraphx/op/elu.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/op/leaky_relu.hpp>
#include <migraphx/gpu/argmax.hpp> #include <migraphx/op/lrn.hpp>
#include <migraphx/gpu/argmin.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp> #include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/gpu/relu.hpp>
#include <migraphx/gpu/sigmoid.hpp>
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/elu.hpp> #include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/sub.hpp>
#include <migraphx/gpu/div.hpp>
#include <migraphx/gpu/exp.hpp>
#include <migraphx/gpu/erf.hpp>
#include <migraphx/gpu/log.hpp>
#include <migraphx/gpu/sin.hpp>
#include <migraphx/gpu/sign.hpp>
#include <migraphx/gpu/cos.hpp>
#include <migraphx/gpu/tan.hpp>
#include <migraphx/gpu/sinh.hpp>
#include <migraphx/gpu/cosh.hpp>
#include <migraphx/gpu/tanh.hpp>
#include <migraphx/gpu/asin.hpp>
#include <migraphx/gpu/acos.hpp>
#include <migraphx/gpu/atan.hpp>
#include <migraphx/gpu/asinh.hpp>
#include <migraphx/gpu/acosh.hpp>
#include <migraphx/gpu/atanh.hpp>
#include <migraphx/gpu/mul.hpp>
#include <migraphx/gpu/max.hpp>
#include <migraphx/gpu/min.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/concat.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/gather.hpp>
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/round.hpp>
#include <migraphx/gpu/ceil.hpp>
#include <migraphx/gpu/floor.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_max.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/reduce_min.hpp>
#include <migraphx/gpu/reduce_prod.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp> #include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/prelu.hpp> #include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/recip.hpp> #include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/rnn_variable_seq_lens.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/iterator_for.hpp>
#include <utility> #include <utility>
#include <functional> #include <functional>
#include <algorithm> #include <algorithm>
...@@ -136,61 +95,59 @@ struct miopen_apply ...@@ -136,61 +95,59 @@ struct miopen_apply
add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu); add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
add_miopen_extend_op<miopen_elu, op::elu>("elu", make_elu); add_miopen_extend_op<miopen_elu, op::elu>("elu", make_elu);
add_generic_op<hip_add>("add"); add_generic_op("acos");
add_generic_op<hip_sub>("sub"); add_generic_op("acosh");
add_generic_op<hip_exp>("exp"); add_generic_op("add");
add_generic_op<hip_erf>("erf"); add_generic_op("asin");
add_generic_op<hip_log>("log"); add_generic_op("asinh");
add_generic_op<hip_sin>("sin"); add_generic_op("atan");
add_generic_op<hip_cos>("cos"); add_generic_op("atanh");
add_generic_op<hip_tan>("tan"); add_generic_op("ceil");
add_generic_op<hip_sinh>("sinh"); add_generic_op("contiguous");
add_generic_op<hip_cosh>("cosh"); add_generic_op("cos");
add_generic_op<hip_tanh>("tanh"); add_generic_op("cosh");
add_generic_op<hip_asin>("asin"); add_generic_op("div");
add_generic_op<hip_acos>("acos"); add_generic_op("erf");
add_generic_op<hip_atan>("atan"); add_generic_op("exp");
add_generic_op<hip_asinh>("asinh"); add_generic_op("floor");
add_generic_op<hip_acosh>("acosh"); add_generic_op("log");
add_generic_op<hip_atanh>("atanh"); add_generic_op("max");
add_generic_op<hip_sqrt>("sqrt"); add_generic_op("min");
add_generic_op<hip_mul>("mul"); add_generic_op("mul");
add_generic_op<hip_div>("div"); add_generic_op("pow");
add_generic_op<hip_max>("max"); add_generic_op("prelu");
add_generic_op<hip_min>("min"); add_generic_op("recip");
add_generic_op<hip_rsqrt>("rsqrt"); add_generic_op("relu");
add_generic_op<hip_round>("round"); add_generic_op("round");
add_generic_op<hip_pow>("pow"); add_generic_op("rsqrt");
add_generic_op<hip_sqdiff>("sqdiff"); add_generic_op("sigmoid");
add_generic_op<hip_relu>("relu"); add_generic_op("sign");
add_generic_op<hip_prelu>("prelu"); add_generic_op("sin");
add_generic_op<hip_sign>("sign"); add_generic_op("sinh");
add_generic_op<hip_sigmoid>("sigmoid"); add_generic_op("sqdiff");
add_generic_op<hip_ceil>("ceil"); add_generic_op("sqrt");
add_generic_op<hip_floor>("floor"); add_generic_op("sub");
add_generic_op<hip_recip>("recip"); add_generic_op("tan");
add_generic_op<miopen_contiguous>("contiguous"); add_generic_op("tanh");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op("argmax");
add_extend_op<hip_softmax, op::softmax>("softmax"); add_extend_op("argmin");
add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax"); add_extend_op("clip");
add_extend_op<hip_argmax, op::argmax>("argmax"); add_extend_op("concat");
add_extend_op<hip_argmin, op::argmin>("argmin"); add_extend_op("convert");
add_extend_op<hip_gather, op::gather>("gather"); add_extend_op("gather");
add_extend_op<hip_pad, op::pad>("pad"); add_extend_op("logsoftmax");
add_extend_op<hip_convert, op::convert>("convert"); add_extend_op("pad");
add_extend_op<hip_clip, op::clip>("clip"); add_extend_op("reduce_max");
add_extend_op<hip_reduce_max, op::reduce_max>("reduce_max"); add_extend_op("reduce_mean");
add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean"); add_extend_op("reduce_min");
add_extend_op<hip_reduce_min, op::reduce_min>("reduce_min"); add_extend_op("reduce_prod");
add_extend_op<hip_reduce_prod, op::reduce_prod>("reduce_prod"); add_extend_op("reduce_sum");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum"); add_extend_op("rnn_var_sl_last_output");
add_extend_op<hip_rnn_var_sl_shift_output, op::rnn_var_sl_shift_output>( add_extend_op("rnn_var_sl_shift_output");
"rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op<hip_rnn_var_sl_shift_sequence, op::rnn_var_sl_shift_sequence>( add_extend_op("softmax");
"rnn_var_sl_shift_sequence");
add_extend_op<hip_rnn_var_sl_last_output, op::rnn_var_sl_last_output>(
"rnn_var_sl_last_output");
add_gemm_op<op::dot>("dot"); add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot"); add_gemm_op<op::quant_dot>("quant_dot");
add_lrn_op(); add_lrn_op();
...@@ -379,28 +336,30 @@ struct miopen_apply ...@@ -379,28 +336,30 @@ struct miopen_apply
}); });
} }
template <class T> void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); }
void add_generic_op(std::string name)
void add_generic_op(const std::string& op_name, const std::string& gpu_name)
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(op_name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs(); std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output); refs.push_back(output);
return prog->replace_instruction(ins, T{}, refs); return prog->replace_instruction(ins, make_op(gpu_name), refs);
}); });
} }
template <class T, class Op> void add_extend_op(const std::string& name) { add_extend_op(name, "gpu::" + name); }
void add_extend_op(std::string name)
void add_extend_op(const std::string& op_name, const std::string& gpu_name)
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(op_name, [=](instruction_ref ins) {
auto&& op = any_cast<Op>(ins->get_operator()); auto&& op = ins->get_operator();
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs(); std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output); refs.push_back(output);
return prog->replace_instruction(ins, T{op}, refs); return prog->replace_instruction(ins, make_op(gpu_name, op.to_value()), refs);
}); });
} }
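The two helpers above no longer take a template parameter per GPU type: add_generic_op infers the lowered name by prefixing "gpu::", and add_extend_op carries the original operator's attributes across by serializing them with to_value(). A small sketch of that attribute round trip on a reference operator, assuming only that operation exposes to_value() as the lowering code above relies on:

#include <migraphx/make_op.hpp>
#include <cassert>

int main()
{
    // Build an operator by name, serialize its fields, and rebuild it from the value.
    migraphx::operation src  = migraphx::make_op("reduce_max", {{"axes", {0}}});
    migraphx::value attrs    = src.to_value(); // e.g. {axes: [0]}
    migraphx::operation same = migraphx::make_op(src.name(), attrs);
    assert(src == same);
    // In add_extend_op the rebuilt name is the gpu counterpart instead,
    // i.e. make_op("gpu::" + src.name(), attrs), which requires the gpu target ops
    // to be registered.
    return 0;
}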
...@@ -472,7 +431,8 @@ struct miopen_apply ...@@ -472,7 +431,8 @@ struct miopen_apply
std::vector<float> zeros(s.elements(), 0.0f); std::vector<float> zeros(s.elements(), 0.0f);
auto l0 = prog->add_literal(literal(s, zeros)); auto l0 = prog->add_literal(literal(s, zeros));
auto output = insert_allocation(ins, s); auto output = insert_allocation(ins, s);
return prog->replace_instruction(ins, hip_sub{}, l0, ins->inputs().front(), output); return prog->replace_instruction(
ins, make_op("gpu::sub"), l0, ins->inputs().front(), output);
}); });
} }
}; };
......
...@@ -209,6 +209,14 @@ std::vector<value>& get_array_impl(const std::shared_ptr<value_base_impl>& x) ...@@ -209,6 +209,14 @@ std::vector<value>& get_array_impl(const std::shared_ptr<value_base_impl>& x)
return *a; return *a;
} }
std::vector<value>& get_array_throw(const std::shared_ptr<value_base_impl>& x)
{
auto* a = if_array_impl(x);
if(a == nullptr)
MIGRAPHX_THROW("Expected an array or object");
return *a;
}
value* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key) value* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key)
{ {
auto* a = if_array_impl(x); auto* a = if_array_impl(x);
...@@ -302,15 +310,29 @@ const value& value::at(const std::string& pkey) const ...@@ -302,15 +310,29 @@ const value& value::at(const std::string& pkey) const
{ {
auto* r = find(pkey); auto* r = find(pkey);
if(r == nullptr) if(r == nullptr)
MIGRAPHX_THROW("Not an object"); MIGRAPHX_THROW("Not an object for field: " + pkey);
if(r == end()) if(r == end())
MIGRAPHX_THROW("Key not found"); MIGRAPHX_THROW("Key not found: " + pkey);
return *r; return *r;
} }
value& value::operator[](std::size_t i) { return *(begin() + i); } value& value::operator[](std::size_t i) { return *(begin() + i); }
const value& value::operator[](std::size_t i) const { return *(begin() + i); } const value& value::operator[](std::size_t i) const { return *(begin() + i); }
value& value::operator[](const std::string& pkey) { return *emplace(pkey, nullptr).first; } value& value::operator[](const std::string& pkey) { return *emplace(pkey, nullptr).first; }
void value::clear() { get_array_throw(x).clear(); }
void value::resize(std::size_t n)
{
if(not is_array())
MIGRAPHX_THROW("Expected an array.");
get_array_impl(x).resize(n);
}
void value::resize(std::size_t n, const value& v)
{
if(not is_array())
MIGRAPHX_THROW("Expected an array.");
get_array_impl(x).resize(n, v);
}
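A short usage sketch of the new mutators, consistent with the tests added later in this commit: clear() accepts both arrays and objects (via get_array_throw), while resize() is restricted to arrays and fills new slots with null or with the supplied value.

#include <migraphx/value.hpp>
#include <cassert>

int main()
{
    migraphx::value arr = {1, 2, 3};
    arr.resize(5);             // new elements are null
    assert(arr.size() == 5 and arr.at(4).is_null());
    arr.resize(7, 9);          // pad with copies of the given value
    assert(arr.at(6).to<int>() == 9);
    arr.clear();
    assert(arr.empty());

    migraphx::value obj = {{"a", 1}};
    obj.clear();               // clear also works on objects
    assert(obj.empty());
    // obj.resize(2);          // would throw: resize expects an array
    return 0;
}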
std::pair<value*, bool> value::insert(const value& v) std::pair<value*, bool> value::insert(const value& v)
{ {
if(v.key.empty()) if(v.key.empty())
......
...@@ -44,6 +44,9 @@ function(add_test_command NAME EXE) ...@@ -44,6 +44,9 @@ function(add_test_command NAME EXE)
# --args $<TARGET_FILE:${EXE}> ${ARGN}) # --args $<TARGET_FILE:${EXE}> ${ARGN})
set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME}) set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME})
file(MAKE_DIRECTORY ${TEST_DIR}) file(MAKE_DIRECTORY ${TEST_DIR})
if (NOT EXISTS ${TEST_DIR})
message(FATAL_ERROR "Failed to create test directory: ${TEST_DIR}")
endif()
file(GENERATE OUTPUT "${TEST_DIR}/run.cmake" file(GENERATE OUTPUT "${TEST_DIR}/run.cmake"
CONTENT " CONTENT "
# Remove previous core dump # Remove previous core dump
......
...@@ -2014,7 +2014,7 @@ TEST_CASE(transpose_gather_test) ...@@ -2014,7 +2014,7 @@ TEST_CASE(transpose_gather_test)
auto prog = optimize_onnx("transpose_gather_test.onnx"); auto prog = optimize_onnx("transpose_gather_test.onnx");
EXPECT(p == prog); EXPECT(p.sort() == prog.sort());
} }
TEST_CASE(undefined_test) TEST_CASE(undefined_test)
......
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "test.hpp" #include "test.hpp"
...@@ -13,6 +15,35 @@ TEST_CASE(load_op) ...@@ -13,6 +15,35 @@ TEST_CASE(load_op)
} }
} }
TEST_CASE(make_op)
{
for(const auto& name : migraphx::get_operators())
{
auto op = migraphx::load_op(name);
CHECK(op == migraphx::make_op(name));
}
}
TEST_CASE(make_op_from_value1)
{
migraphx::operation x = migraphx::make_op(
"convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}});
migraphx::operation y = migraphx::op::convolution{{1, 1}, {2, 2}, {2, 2}};
EXPECT(x == y);
}
TEST_CASE(make_op_from_value2)
{
migraphx::operation x = migraphx::make_op("convolution", {{"padding", {1, 1}}});
migraphx::operation y = migraphx::op::convolution{{1, 1}};
EXPECT(x == y);
}
TEST_CASE(make_op_invalid_key)
{
EXPECT(test::throws([] { migraphx::make_op("convolution", {{"paddings", {1, 1}}}); }));
}
TEST_CASE(ops) TEST_CASE(ops)
{ {
auto names = migraphx::get_operators(); auto names = migraphx::get_operators();
......
...@@ -69,4 +69,35 @@ TEST_CASE(serialize_reflectable_type) ...@@ -69,4 +69,35 @@ TEST_CASE(serialize_reflectable_type)
EXPECT(v2 != v3); EXPECT(v2 != v3);
} }
TEST_CASE(serialize_empty_array)
{
std::vector<std::size_t> ints = {};
migraphx::value v = migraphx::to_value(ints);
EXPECT(v.is_array());
EXPECT(v.empty());
v.push_back(1);
EXPECT(v.size() == 1);
EXPECT(v.front().to<int>() == 1);
}
struct empty_struct
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return migraphx::pack();
}
};
TEST_CASE(serialize_empty_struct)
{
empty_struct es{};
migraphx::value v = migraphx::to_value(es);
EXPECT(v.is_object());
EXPECT(v.empty());
v["a"] = 1;
EXPECT(v.size() == 1);
EXPECT(v.at("a").to<int>() == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -563,4 +563,132 @@ TEST_CASE(print) ...@@ -563,4 +563,132 @@ TEST_CASE(print)
EXPECT(ss.str() == "{1, {one: 1, two: 2}, {1, 2}, null}"); EXPECT(ss.str() == "{1, {one: 1, two: 2}, {1, 2}, null}");
} }
TEST_CASE(value_clear)
{
migraphx::value values = {1, 2, 3};
EXPECT(values.is_array());
EXPECT(values.size() == 3);
values.clear();
EXPECT(values.empty());
values.push_back(3);
EXPECT(values.size() == 1);
EXPECT(values.at(0).to<int>() == 3);
}
TEST_CASE(value_clear_non_array)
{
migraphx::value values = 1.0;
EXPECT(test::throws([&] { values.clear(); }));
}
TEST_CASE(value_clear_object)
{
migraphx::value values = {{"a", 1}, {"b", 2}};
EXPECT(values.is_object());
EXPECT(values.size() == 2);
values.clear();
EXPECT(values.empty());
values["c"] = 3;
EXPECT(values.size() == 1);
EXPECT(values.at("c").to<int>() == 3);
}
TEST_CASE(value_clear_empty_array)
{
migraphx::value values = migraphx::value::array{};
EXPECT(values.empty());
values.clear();
EXPECT(values.empty());
}
TEST_CASE(value_clear_empty_object)
{
migraphx::value values = migraphx::value::object{};
EXPECT(values.empty());
values.clear();
EXPECT(values.empty());
}
TEST_CASE(value_resize)
{
migraphx::value values = {1, 2, 3};
EXPECT(values.is_array());
EXPECT(values.size() == 3);
values.resize(5);
EXPECT(values.size() == 5);
EXPECT(values.at(3).is_null());
EXPECT(values.at(4).is_null());
}
TEST_CASE(value_resize_with_value)
{
migraphx::value values = {1, 2, 3};
EXPECT(values.is_array());
EXPECT(values.size() == 3);
values.resize(5, 7);
EXPECT(values.size() == 5);
EXPECT(values.at(3).to<int>() == 7);
EXPECT(values.at(4).to<int>() == 7);
}
TEST_CASE(value_resize_empty_array)
{
migraphx::value values = migraphx::value::array{};
EXPECT(values.is_array());
EXPECT(values.empty());
values.resize(3);
EXPECT(values.size() == 3);
EXPECT(values.at(0).is_null());
EXPECT(values.at(1).is_null());
EXPECT(values.at(2).is_null());
}
TEST_CASE(value_resize_object)
{
migraphx::value values = migraphx::value::object{};
EXPECT(values.is_object());
EXPECT(test::throws([&] { values.resize(4); }));
}
TEST_CASE(value_resize_n_object)
{
migraphx::value values = migraphx::value::object{};
EXPECT(values.is_object());
EXPECT(test::throws([&] { values.resize(4, ""); }));
}
TEST_CASE(value_assign_construct_from_vector)
{
std::vector<int> v = {1, 2, 3};
migraphx::value values = v;
EXPECT(values.to_vector<int>() == v);
}
TEST_CASE(value_construct_from_vector)
{
std::vector<int> v = {1, 2, 3};
migraphx::value values(v);
EXPECT(values.to_vector<int>() == v);
}
TEST_CASE(value_assign_from_vector)
{
std::vector<int> v = {1, 2, 3};
migraphx::value values{};
values = v;
EXPECT(values.to_vector<int>() == v);
}
TEST_CASE(value_init_from_vector)
{
std::vector<int> v = {1, 2, 3};
migraphx::value values = {{"a", v}};
EXPECT(values.at("a").to_vector<int>() == v);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }