Commit bf6f82d8 authored by Paul

Merge from develop

parents 6a0797e2 b93f5320
@@ -18,19 +18,20 @@ namespace op {
 struct leaky_relu
 {
-    std::string name() const { return "leaky_relu"; }
     float alpha;
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs, *this}.has(1);
-        return inputs.front();
-    }
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
         return pack(f(self.alpha, "alpha"));
     }
+    std::string name() const { return "leaky_relu"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        return inputs.front();
+    }
 };
 } // namespace op
...
@@ -39,7 +39,7 @@ struct load
         MIGRAPHX_THROW("Load access is out of bounds");
         return {s, args[0].data() + offset};
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
     friend std::ostream& operator<<(std::ostream& os, const load& op)
     {
...
@@ -29,7 +29,7 @@ struct logsoftmax
     std::string name() const { return "logsoftmax"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs}.has(1);
+        check_shapes{inputs}.has(1).standard();
         if(axis < 0 || axis > inputs[0].lens().size())
         {
             MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
...
@@ -42,7 +42,7 @@ struct multibroadcast
         std::vector<size_t> bcast_strides(output_lens.size(), 0);
         auto offset = output_lens.size() - input.lens().size();
-        for(int i = input.lens().size() - 1; i >= 0; i--)
+        for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
         {
             if(output_lens[i + offset] == input.lens()[i])
             {
@@ -55,7 +55,7 @@ struct multibroadcast
     {
         return {std::move(output_shape), std::move(args.at(0).data)};
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
 } // namespace op
...
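The `int` to `std::ptrdiff_t` switch here (and in the `output_alias` signatures throughout this commit) avoids narrowing: `input.lens().size()` returns `std::size_t`, which is wider than `int` on LP64 targets, so the initializer silently truncates for very large containers and trips `-Wconversion`. A minimal sketch, not from the diff, of the descending-index pattern the change standardizes on:

```cpp
#include <cstddef>
#include <vector>

void visit_backwards(const std::vector<std::size_t>& lens)
{
    // std::ptrdiff_t is signed and sized to the pointer/index range, so it
    // holds lens.size() - 1 without narrowing and can go negative, which is
    // what terminates the descending loop.
    for(std::ptrdiff_t i = static_cast<std::ptrdiff_t>(lens.size()) - 1; i >= 0; i--)
    {
        // use lens[i] from back to front
    }
}
```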
@@ -9,6 +9,7 @@
 #include <migraphx/streamutils.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/shape_for_each.hpp>
+#include <migraphx/int_divide.hpp>
 #include <migraphx/config.hpp>
 #include <cmath>
 #include <utility>
@@ -49,20 +50,19 @@ struct pooling
         if(padding_mode == default_)
         {
-            return {
-                t,
+            return {t,
                 {
                     input.lens()[0],
                     input.lens()[1],
                     std::size_t(std::max<std::ptrdiff_t>(
                         1,
-                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
-                                                  static_cast<float>(stride[0]))) +
+                        floor_divide<std::ptrdiff_t>(
+                            input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) +
                             1)),
                     std::size_t(std::max<std::ptrdiff_t>(
                         1,
-                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
-                                                  static_cast<float>(stride[1]))) +
+                        floor_divide<std::ptrdiff_t>(
+                            input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) +
                             1)),
                 }};
         }
@@ -71,27 +71,22 @@ struct pooling
             return {t,
                     {input.lens()[0],
                      input.lens()[1],
-                     static_cast<std::size_t>(
-                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
-                     static_cast<std::size_t>(
-                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
+                     ceil_divide<std::size_t>(input.lens()[2], stride[0]),
+                     ceil_divide<std::size_t>(input.lens()[3], stride[1])}};
         }
         else if(padding_mode == valid)
         {
-            return {t,
+            return {
+                t,
                 {
                     input.lens()[0],
                     input.lens()[1],
                     std::size_t(std::max<std::ptrdiff_t>(
                         1,
-                        std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
-                                                  static_cast<float>(stride[0]))) +
-                            1)),
+                        floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)),
                     std::size_t(std::max<std::ptrdiff_t>(
                         1,
-                        std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
-                                                  static_cast<float>(stride[1]))) +
-                            1)),
+                        floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)),
                 }};
         }
         else
...
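The new `floor_divide`/`ceil_divide` helpers replace float round-trips such as `std::floor(a / static_cast<float>(b))`, which can misround once the numerator exceeds float's 24-bit integer precision. The header `<migraphx/int_divide.hpp>` is not shown in this diff, so the following is only a plausible sketch consistent with the call sites above:

```cpp
#include <cstddef>

// Assumed shapes of the helpers in <migraphx/int_divide.hpp>: the result
// type R is supplied explicitly at the call site, e.g.
// floor_divide<std::ptrdiff_t>(x, y).
template <class R, class T, class U>
R ceil_divide(T x, U y)
{
    // Round the quotient up; assumes y > 0, which holds for strides.
    return (R(x) + R(y) - 1) / R(y);
}

template <class R, class T, class U>
R floor_divide(T x, U y)
{
    // Round toward negative infinity. Plain integer division truncates
    // toward zero, which differs when the numerator is negative (possible
    // above when lengths[i] exceeds the padded input extent).
    R q = R(x) / R(y);
    R r = R(x) % R(y);
    return (r != 0 && ((r < 0) != (R(y) < 0))) ? q - 1 : q;
}
```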
@@ -66,7 +66,7 @@ struct reshape
     {
         return {std::move(output_shape), std::move(args.front().data)};
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
 } // namespace op
...
@@ -40,7 +40,7 @@ struct scalar
     {
         return {std::move(output_shape), std::move(args.at(0).data)};
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
 } // namespace op
...
@@ -86,7 +86,7 @@ struct slice
         auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
         return {std::move(output_shape), [=] { return input.data() + offset; }};
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
 } // namespace op
...
@@ -70,7 +70,7 @@ struct squeeze
     {
         return {std::move(output_shape), std::move(args.front().data)};
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
 } // namespace op
...
@@ -57,7 +57,7 @@ struct transpose
     {
         return {std::move(output_shape), std::move(args.front().data)};
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
 } // namespace op
...
@@ -13,7 +13,15 @@ struct unary : op_name<Derived>
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1);
-        return inputs.at(0);
+        auto s = inputs.at(0);
+        if(s.packed())
+        {
+            return s;
+        }
+        else
+        {
+            return {s.type(), s.lens()};
+        }
     }
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
...
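The effect: a unary operator now returns its input shape only when that shape is packed; otherwise it produces a fresh shape with default (packed) strides, since the elementwise result is written densely. An illustrative sketch, assuming the public `migraphx::shape` constructors:

```cpp
#include <migraphx/shape.hpp>

void unary_shape_example()
{
    using migraphx::shape;
    shape packed{shape::float_type, {2, 3}};           // default strides {3, 1}, packed
    shape strided{shape::float_type, {2, 2}, {4, 1}};  // e.g. a sliced view, not packed
    // compute_shape({packed})  -> returns `packed` unchanged
    // compute_shape({strided}) -> returns shape{float_type, {2, 2}} with packed strides
}
```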
@@ -29,10 +29,14 @@ struct unsqueeze
     std::string name() const { return "unsqueeze"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
+        check_shapes{inputs, *this}.has(1).standard_or_scalar();
         auto input_shape = inputs[0];
         auto type        = input_shape.type();
         auto old_lens    = input_shape.lens();
+        if(input_shape.scalar())
+            return shape{type, old_lens};
         std::size_t new_size = old_lens.size() + axes.size();
         std::vector<std::size_t> new_lens(new_size);
         std::size_t p = 0;
@@ -53,7 +57,7 @@ struct unsqueeze
     {
         return {std::move(output_shape), std::move(args.front().data)};
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
 } // namespace op
...
@@ -49,7 +49,7 @@ struct operation
     argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
     /// An optional method to return which argument the output will alias. If
     /// there is no aliased output then -1 can be returned.
-    int output_alias(const std::vector<shape>& input) const;
+    std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
     /// An optional stream operator to print the operation. When this is not
     /// implemented, it will just print the operation's name.
     friend std::ostream& operator<<(std::ostream& os, const operation& op);
@@ -69,7 +69,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
 {
     os << x.name();
     char delim = '[';
-    reflect_each(x, [&](auto& y, auto name) {
+    reflect_each(x, [&](auto&& y, auto name) {
         os << delim;
         os << name << "=";
         stream_write_value(os, y);
@@ -87,6 +87,8 @@ namespace operation_equal {
 template <class T, class U>
 auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
 {
+    static_assert(is_reflectable<T>{} or sizeof(T) <= 1,
+                  "Missing equality operator or reflect method.");
     if(x.name() != y.name())
         return false;
     const auto& yy = any_cast<T>(y);
@@ -175,7 +177,7 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
 }
 
 template <class T>
-int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
+std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&)
 {
     return -1;
 }
@@ -188,7 +190,7 @@ auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
 }
 
 template <class T>
-int output_alias_op(const T& x, const std::vector<shape>& shapes)
+std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
 {
     return output_alias_op(rank<1>{}, x, shapes);
 }
@@ -239,7 +241,7 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
 * std::string name() const;
 * bool is_context_free() const;
 * bool has_finalize() const;
-* int output_alias(const std::vector<shape>& input) const;
+* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
 * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
 * shape compute_shape(const std::vector<shape>& input) const;
 * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
@@ -325,7 +327,7 @@ struct operation
         return (*this).private_detail_te_get_handle().has_finalize();
     }
 
-    int output_alias(const std::vector<shape>& input) const
+    std::ptrdiff_t output_alias(const std::vector<shape>& input) const
     {
         assert((*this).private_detail_te_handle_mem_var);
         return (*this).private_detail_te_get_handle().output_alias(input);
@@ -383,7 +385,7 @@ struct operation
     virtual std::string name() const = 0;
     virtual bool is_context_free() const = 0;
     virtual bool has_finalize() const = 0;
-    virtual int output_alias(const std::vector<shape>& input) const = 0;
+    virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
     virtual void
     finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
     virtual shape compute_shape(const std::vector<shape>& input) const = 0;
@@ -432,7 +434,7 @@ struct operation
     bool has_finalize() const override { return has_finalize_op(private_detail_te_value); }
 
-    int output_alias(const std::vector<shape>& input) const override
+    std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
     {
         return output_alias_op(private_detail_te_value, input);
...
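The `output_alias` type change gives the interface a signed type wide enough for any argument index while keeping -1 as the "no alias" sentinel supplied by the `rank<0>` fallback. A minimal sketch, not from this diff, of a view-like operator opting in:

```cpp
#include <cstddef>
#include <string>
#include <vector>

struct shape
{
}; // stand-in for migraphx::shape in this sketch

// Hypothetical operator for illustration only.
struct pass_through
{
    std::string name() const { return "pass_through"; }
    // Returning 0 declares that the output aliases input 0, so buffer-reuse
    // passes must treat the two as the same storage. Operators that omit
    // this method get the rank<0> fallback above, which returns -1.
    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
```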
@@ -11,6 +11,7 @@
 #include <migraphx/op/batch_norm.hpp>
 #include <migraphx/op/binary.hpp>
 #include <migraphx/op/broadcast.hpp>
+#include <migraphx/op/clip.hpp>
 #include <migraphx/op/common.hpp>
 #include <migraphx/op/concat.hpp>
 #include <migraphx/op/contiguous.hpp>
...
@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins);
 struct program
 {
     program();
+    // move constructor
     program(program&&) noexcept;
+    program& operator=(program&&) noexcept;
+    // copy constructor
+    program(const program&);
+    // copy assignment operator
+    program& operator=(program);
     ~program() noexcept;
 
     using parameter_map = std::unordered_map<std::string, argument>;
@@ -118,6 +126,9 @@ struct program
     friend bool operator==(const program& x, const program& y);
     friend bool operator!=(const program& x, const program& y) { return !(x == y); }
 
+    private:
+    void assign(const program& p);
+
     private:
     std::unique_ptr<program_impl> impl;
 };
...
@@ -11,6 +11,15 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace detail {
 
+struct reflect_placeholder
+{
+    template <class... Ts>
+    int operator()(Ts&&...) const
+    {
+        return 0;
+    }
+};
+
 template <class T, class Selector>
 auto reflect_impl(rank<1>, T& x, Selector f) -> decltype(T::reflect(x, f))
 {
@@ -23,8 +32,53 @@ auto reflect_impl(rank<0>, T&, Selector)
     return pack();
 }
 
+template <class T>
+auto reflectable_impl(rank<1>, T&& x)
+    -> decltype(T::reflect(x, reflect_placeholder{}), std::true_type{});
+
+template <class T>
+auto reflectable_impl(rank<0>, T &&) -> decltype(std::false_type{});
+
+template <class T>
+struct remove_rvalue_reference
+{
+    using type = T;
+};
+
+template <class T>
+struct remove_rvalue_reference<T&&>
+{
+    using type = T;
+};
+
+template <class T>
+struct wrapper
+{
+    using type = typename remove_rvalue_reference<T>::type;
+    type data;
+    type get() const { return data; }
+};
+
+template <class T>
+wrapper<T> wrap(std::remove_reference_t<T>& x)
+{
+    return wrapper<T>{std::forward<T>(x)};
+}
+
+template <class... Ts>
+using auto_tuple_t = std::tuple<typename remove_rvalue_reference<Ts>::type...>;
+
+template <class... Ts>
+auto_tuple_t<Ts...> auto_tuple(Ts&&... xs)
+{
+    return auto_tuple_t<Ts...>{std::forward<Ts>(xs)...};
+}
+
 } // namespace detail
 
+template <class T>
+using is_reflectable = decltype(detail::reflectable_impl(rank<1>{}, std::declval<T>()));
+
 template <class T, class Selector>
 auto reflect(T& x, Selector f)
 {
@@ -34,15 +88,16 @@ auto reflect(T& x, Selector f)
 template <class T>
 auto reflect_tie(T& x)
 {
-    return reflect(x, [](auto&& y, auto&&...) { return std::ref(y); })(
-        [](auto&&... xs) { return std::tie(xs.get()...); });
+    return reflect(x, [](auto&& y, auto&&...) { return detail::wrap<decltype(y)>(y); })(
+        [](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
 }
 
 template <class T, class F>
 void reflect_each(T& x, F f)
 {
-    return reflect(x, [](auto&& y, auto... ys) { return pack(std::ref(y), ys...); })(
-        [&](auto&&... xs) {
+    return reflect(x, [](auto&& y, auto... ys) {
+        return pack(detail::wrap<decltype(y)>(y), ys...);
+    })([&](auto&&... xs) {
         each_args([&](auto p) { p([&](auto&& y, auto... ys) { f(y.get(), ys...); }); }, xs...);
     });
 }
...
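Together these pieces form a compile-time "does T have a reflect method" probe: `reflectable_impl(rank<1>, ...)` is only viable when `T::reflect(x, reflect_placeholder{})` compiles, so `is_reflectable` collapses to `std::true_type` or `std::false_type` by overload ranking. The `wrapper`/`auto_tuple` machinery likewise replaces `std::ref`/`std::tie` so that values returned from `reflect` (which `std::ref` cannot bind to when they are temporaries) can still be tied. A small usage sketch with illustrative structs:

```cpp
#include <migraphx/reflect.hpp> // assumed location of is_reflectable; pack as used above

struct with_reflect
{
    float alpha;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.alpha, "alpha"));
    }
};

struct without_reflect
{
    float alpha;
};

// reflectable_impl(rank<1>, ...) participates only when T::reflect compiles.
static_assert(migraphx::is_reflectable<with_reflect>{}, "detected via T::reflect");
static_assert(not migraphx::is_reflectable<without_reflect>{}, "no reflect method");
```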
@@ -38,8 +38,9 @@ inline std::string join_strings(Strings strings, const std::string& delim)
         return "";
     auto nit = std::next(it);
-    return std::accumulate(
-        nit, strings.end(), *it, [&](std::string x, std::string y) { return x + delim + y; });
+    return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) {
+        return std::move(x) + delim + std::move(y);
+    });
 }
 
 template <class F>
...
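Moving the accumulator matters here: `x + delim + y` copies the growing result on every step (quadratic in total output size), while `std::move(x) + delim + std::move(y)` lets the string reuse its buffer. A standalone usage sketch of the same pattern:

```cpp
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

int main()
{
    std::vector<std::string> strings{"conv", "relu", "pool"};
    std::string delim = " -> ";
    auto it  = strings.begin();
    auto nit = std::next(it);
    // Same pattern as the patched join_strings: the accumulator is moved
    // through each step instead of being re-copied.
    std::string joined =
        std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) {
            return std::move(x) + delim + std::move(y);
        });
    std::cout << joined << "\n"; // conv -> relu -> pool
}
```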
@@ -28,6 +28,12 @@ void instruction::replace(const shape& r)
     }
 }
 
+void instruction::replace(operation o)
+{
+    op = std::move(o);
+    recompute_shape();
+}
+
 void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
 
 void instruction::clear_arguments()
@@ -162,7 +168,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
     old->remove_output(*this);
 }
 
-argument instruction::eval() const
+bool instruction::can_eval() const
+{
+    if(op.name() == "@literal")
+    {
+        return true;
+    }
+    else if(is_context_free(op))
+    {
+        return std::all_of(
+            this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
+    }
+    else
+    {
+        return false;
+    }
+}
+
+argument instruction::eval(bool check_eval) const
 {
     if(op.name() == "@literal")
     {
@@ -170,14 +193,13 @@ argument instruction::eval() const
     }
     if(is_context_free(op))
     {
-        std::vector<argument> args;
-        for(auto&& arg : this->inputs())
-        {
-            argument a = arg->eval();
-            if(a.empty())
-                return {};
-            args.push_back(a);
-        }
+        if(check_eval and not this->can_eval())
+            return {};
+        std::vector<argument> args;
+        std::transform(this->inputs().begin(),
+                       this->inputs().end(),
+                       std::back_inserter(args),
+                       [](auto arg) { return arg->eval(false); });
         return op.compute(result, args);
     }
     return {};
...
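The split makes constant evaluation two-pass: `can_eval` first verifies that the whole input subtree is evaluable, then `eval(true)` computes with `eval(false)` on the inputs, so the check is not repeated at every level of the recursion as the old "eval and test for empty" loop effectively did. A simplified standalone analogue of the traversal:

```cpp
#include <algorithm>
#include <vector>

// Stand-in node type for illustration; not the MIGraphX instruction class.
struct node
{
    bool is_literal   = false;
    bool context_free = false;
    std::vector<const node*> inputs;

    bool can_eval() const
    {
        if(is_literal)
            return true;
        if(context_free)
            return std::all_of(inputs.begin(), inputs.end(), [](const node* n) {
                return n->can_eval();
            });
        return false; // depends on a runtime context, cannot be folded
    }
};
```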
@@ -32,7 +32,7 @@ auto read_cifar10_images(const std::string& full_path)
         labels[i] = *pimage++;
         for(size_t j = 0; j < nbytes_per_image; j++)
         {
-            float v = *(pimage + j) / 255.0f;
+            float v = float(*(pimage + j)) / 255.0f;
             data[i * nbytes_per_image + j] = v;
         }
     }
...
@@ -63,6 +63,7 @@ struct onnx_parser
         add_variadic_op("Max", op::max{});
         add_variadic_op("Min", op::min{});
+        add_mem_op("Clip", &onnx_parser::parse_clip);
         add_mem_op("LRN", &onnx_parser::parse_lrn);
         add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
         add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
@@ -207,7 +208,7 @@ struct onnx_parser
     template <class T>
     void add_generic_op(std::string name, T x)
     {
-        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
             return prog.add_instruction(x, args);
         });
     }
@@ -215,7 +216,7 @@ struct onnx_parser
     template <class T>
     void add_variadic_op(std::string name, T x)
     {
-        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
             return std::accumulate(std::next(args.begin()),
                                    args.end(),
                                    args.front(),
@@ -225,6 +226,22 @@ struct onnx_parser
         });
     }
 
+    instruction_ref parse_clip(const std::string&,
+                               const attribute_map& attributes,
+                               std::vector<instruction_ref> args)
+    {
+        op::clip op;
+        if(contains(attributes, "max"))
+        {
+            op.max_val = parse_value(attributes.at("max")).at<float>();
+        }
+        if(contains(attributes, "min"))
+        {
+            op.min_val = parse_value(attributes.at("min")).at<float>();
+        }
+        return prog.add_instruction(op, std::move(args));
+    }
+
     instruction_ref
     parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
     {
@@ -1361,28 +1378,26 @@ struct onnx_parser
     static literal parse_tensor(const onnx::TensorProto& t)
     {
         std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
-        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
-        if(dims.empty())
-        {
-            dims = {1};
-        }
         if(t.has_raw_data())
         {
             const std::string& s = t.raw_data();
             switch(t.data_type())
             {
             case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
-            case onnx::TensorProto::FLOAT: return literal{{shape::float_type, dims}, s.data()};
+            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
             case onnx::TensorProto::UINT8: throw std::runtime_error("");
-            case onnx::TensorProto::INT8: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::UINT16: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::INT16: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::INT32: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
+            case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::UINT16:
+                return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
             case onnx::TensorProto::STRING: throw std::runtime_error("");
-            case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::FLOAT16: return literal{{shape::half_type, dims}, s.data()};
-            case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
+            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::FLOAT16:
+                return create_literal(shape::half_type, dims, s.data());
+            case onnx::TensorProto::DOUBLE:
+                return create_literal(shape::double_type, dims, s.data());
             case onnx::TensorProto::UINT32: throw std::runtime_error("");
             case onnx::TensorProto::UINT64: throw std::runtime_error("");
             case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
@@ -1394,21 +1409,21 @@ struct onnx_parser
         {
         case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
         case onnx::TensorProto::FLOAT:
-            return literal{{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
+            return create_literal(shape::float_type, dims, t.float_data());
         case onnx::TensorProto::UINT8: throw std::runtime_error("");
         case onnx::TensorProto::INT8:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::UINT16:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::INT16:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::INT32:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::INT64:
-            return literal{{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
+            return create_literal(shape::int64_type, dims, t.int64_data());
         case onnx::TensorProto::STRING: throw std::runtime_error("");
         case onnx::TensorProto::BOOL:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::FLOAT16:
         {
             std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
@@ -1417,11 +1432,10 @@ struct onnx_parser
                            data_uint16.end(),
                            std::back_inserter(data_half),
                            [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
-            return literal{{shape::half_type, dims}, data_half.begin(), data_half.end()};
+            return create_literal(shape::half_type, dims, data_half);
         }
         case onnx::TensorProto::DOUBLE:
-            return literal{
-                {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
+            return create_literal(shape::double_type, dims, t.double_data());
         case onnx::TensorProto::UINT32: throw std::runtime_error("");
         case onnx::TensorProto::UINT64: throw std::runtime_error("");
         case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
@@ -1430,6 +1444,23 @@ struct onnx_parser
         MIGRAPHX_THROW("Invalid tensor type");
     }
 
+    static literal
+    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
+    {
+        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
+        if(dims.empty())
+            return literal{{shape_type}, data};
+        return literal{{shape_type, dims}, data};
+    }
+
+    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
+    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
+    {
+        if(dims.empty())
+            return literal{{shape_type}, data.begin(), data.end()};
+        return literal{{shape_type, dims}, data.begin(), data.end()};
+    }
+
     static shape parse_type(const onnx::TypeProto& t)
     {
         shape::type_t shape_type{};
...
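With these helpers, an ONNX initializer whose `dims` list is empty becomes a genuine scalar literal instead of being padded to `dims = {1}` as the removed block did. A sketch of the behavior, with illustrative values and assuming the `literal`/`shape` constructors used in the diff:

```cpp
#include <cstddef>
#include <vector>

void create_literal_example()
{
    const std::vector<std::size_t> no_dims{}; // ONNX scalar: t.dims() is empty
    const std::vector<std::size_t> dims{2, 3};
    const std::vector<float> payload{1, 2, 3, 4, 5, 6};

    // create_literal(shape::float_type, no_dims, payload)
    //   -> literal{{shape::float_type}, ...}          (rank-0 scalar shape)
    // create_literal(shape::float_type, dims, payload)
    //   -> literal{{shape::float_type, {2, 3}}, ...}  (ordinary tensor shape)
}
```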