Commit baac1dab authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-host-lib

parents 830dff7a 77042e30
...@@ -33,15 +33,36 @@ ...@@ -33,15 +33,36 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void unregister_op(const std::string& op_name);
namespace detail {
struct op_handler
{
operation op;
std::string name;
op_handler(const operation& op_r) : op(op_r), name(op.name()){};
~op_handler() { unregister_op(name); }
};
} // namespace detail
void register_op_init();
void register_op(const operation& op); void register_op(const operation& op);
operation load_op(const std::string& name); operation load_op(const std::string& name);
bool has_op(const std::string& name); bool has_op(const std::string& name);
std::vector<std::string> get_operators(); std::vector<std::string> get_operators();
template <class T> template <class T>
void register_op() void register_op()
{ {
register_op(T{}); register_op_init(); // instantiate static op_map;
static auto op_h = detail::op_handler(T{});
register_op(op_h.op);
} }
struct register_op_action struct register_op_action
......
...@@ -33,14 +33,28 @@ ...@@ -33,14 +33,28 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void register_target_init();
void register_target(const target& t); void register_target(const target& t);
void unregister_target(const std::string& name);
target make_target(const std::string& name); target make_target(const std::string& name);
std::vector<std::string> get_targets(); std::vector<std::string> get_targets();
namespace detail {
struct target_handler
{
target t;
std::string target_name;
target_handler(const target& t_r) : t(t_r), target_name(t.name()) {}
~target_handler() { unregister_target(target_name); }
};
} // namespace detail
template <class T> template <class T>
void register_target() void register_target()
{ {
register_target(T{}); register_target_init();
static auto t_h = detail::target_handler(T{});
register_target(t_h.t);
} }
struct register_target_action struct register_target_action
......
...@@ -32,6 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,6 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
/**
* Replace `allocate` instructions with target allocations or output parameters.
*/
struct replace_allocate struct replace_allocate
{ {
allocation_model model; allocation_model model;
......
...@@ -93,7 +93,7 @@ auto to_value_impl(rank<4>, const optional<T>& x) ...@@ -93,7 +93,7 @@ auto to_value_impl(rank<4>, const optional<T>& x)
{ {
value result{}; value result{};
if(x.has_value()) if(x.has_value())
to_value(*x); return to_value(*x);
return result; return result;
} }
...@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x) ...@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x)
} }
template <class T> template <class T>
auto from_value_impl(rank<4>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void()) auto from_value_impl(rank<4>, const value& v, T& x)
-> decltype(x.insert(*x.begin()), std::declval<typename T::mapped_type>(), void())
{ {
x.clear(); x.clear();
for(auto&& e : v) for(auto&& e : v)
...@@ -212,28 +213,22 @@ void from_value_impl(rank<6>, const value& v, optional<T>& x) ...@@ -212,28 +213,22 @@ void from_value_impl(rank<6>, const value& v, optional<T>& x)
x = from_value<T>(v); x = from_value<T>(v);
} }
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{} or std::is_enum<T>{})>
void from_value_impl(rank<7>, const value& v, T& x) void from_value_impl(rank<7>, const value& v, T& x)
{ {
x = v.to<T>(); x = v.to<T>();
} }
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})> inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); }
void from_value_impl(rank<8>, const value& v, T& x)
{
x = v.to<T>();
}
inline void from_value_impl(rank<9>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T> template <class T>
auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(x.from_value(v), void()) auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void())
{ {
x.from_value(v); x.from_value(v);
} }
template <class T> template <class T>
auto from_value_impl(rank<11>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void()) auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{ {
migraphx_from_value(v, x); migraphx_from_value(v, x);
} }
...@@ -249,7 +244,7 @@ value to_value(const T& x) ...@@ -249,7 +244,7 @@ value to_value(const T& x)
template <class T> template <class T>
void from_value(const value& v, T& x) void from_value(const value& v, T& x)
{ {
detail::from_value_impl(rank<11>{}, v, x); detail::from_value_impl(rank<10>{}, v, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -29,10 +29,12 @@ ...@@ -29,10 +29,12 @@
#include <ostream> #include <ostream>
#include <numeric> #include <numeric>
#include <memory> #include <memory>
#include <set>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -87,12 +89,12 @@ struct shape ...@@ -87,12 +89,12 @@ struct shape
{ {
std::size_t min = 0; std::size_t min = 0;
std::size_t max = 0; std::size_t max = 0;
std::size_t opt = 0; std::set<std::size_t> optimals{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt")); return pack(f(self.min, "min"), f(self.max, "max"), f(self.optimals, "optimals"));
} }
bool is_fixed() const; bool is_fixed() const;
...@@ -132,11 +134,12 @@ struct shape ...@@ -132,11 +134,12 @@ struct shape
shape(type_t t, std::vector<dynamic_dimension> dims); shape(type_t t, std::vector<dynamic_dimension> dims);
// Construct a dynamic shape from three sets of lengths (of the same rank) // Construct a dynamic shape from vectors of mins, maxes, and optimals.
// optimals_list is a vector of optimals that corresponds to each min and max.
shape(type_t t, shape(type_t t,
std::vector<std::size_t> mins, std::vector<std::size_t> mins,
std::vector<std::size_t> maxes, std::vector<std::size_t> maxes,
std::vector<std::size_t> opts); std::vector<std::set<std::size_t>> optimals_list);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
...@@ -186,21 +189,21 @@ struct shape ...@@ -186,21 +189,21 @@ struct shape
/*! /*!
* Minimum lengths for dynamic shape. * Minimum lengths for dynamic shape.
* lens() for fixed shape. * lens() for static shape.
*/ */
std::vector<std::size_t> min_lens() const; std::vector<std::size_t> min_lens() const;
/*! /*!
* Maximum lengths for dynamic shape. * Maximum lengths for dynamic shape.
* lens() for fixed shape. * lens() for static shape.
*/ */
std::vector<std::size_t> max_lens() const; std::vector<std::size_t> max_lens() const;
/*! /*!
* Optimum lengths for dynamic shape. * Optimum lengths for dynamic shape.
* lens() for fixed shape. * Empty for static shape.
*/ */
std::vector<std::size_t> opt_lens() const; std::vector<std::set<std::size_t>> opt_lens() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
...@@ -219,11 +222,15 @@ struct shape ...@@ -219,11 +222,15 @@ struct shape
/// Map element index to space index /// Map element index to space index
std::size_t index(std::size_t i) const; std::size_t index(std::size_t i) const;
std::vector<std::size_t> multi(std::size_t i) const; /// Map element index to multi-dimensional index
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const; std::vector<std::size_t> multi(std::size_t idx) const;
/// Returns true if the shape is packed (number of elements and buffer size the same) with no /// Map element index to multi-dimensional index and put them them into location provided by
/// padding /// pointers
void multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed (number of elements and buffer size the same) with
/// no padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending /// Returns true is the shape has been transposed. That is the strides are not in descending
...@@ -243,6 +250,9 @@ struct shape ...@@ -243,6 +250,9 @@ struct shape
/// Return true if the shape is dynamic /// Return true if the shape is dynamic
bool dynamic() const; bool dynamic() const;
/// Return true if this shape or any of the sub_shapes are dynamic
bool any_of_dynamic() const;
shape normalize_standard() const; shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const; shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
...@@ -250,9 +260,12 @@ struct shape ...@@ -250,9 +260,12 @@ struct shape
shape with_type(type_t t) const; shape with_type(type_t t) const;
// convert the shape to an equivalent dynamic shape // convert the shape to an equivalent dynamic shape with empty optimals
shape to_dynamic() const; shape to_dynamic() const;
// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x); friend std::ostream& operator<<(std::ostream& os, const shape& x);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#define MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <string>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Split dynamic dimension over submodules if exactly one dimension in the parameter list is
* dynamic.
*/
struct split_single_dyn_dim
{
std::string name() const { return "split_single_dyn_dim"; }
void apply(module_pass_manager&) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -43,6 +43,8 @@ struct tf_options ...@@ -43,6 +43,8 @@ struct tf_options
/// Create a program from a tf pb file (default is nhwc format) /// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, const tf_options& options = tf_options{}); program parse_tf(const std::string& name, const tf_options& options = tf_options{});
std::vector<std::string> get_tf_operators();
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -166,6 +166,7 @@ void module::assign(const module& m) ...@@ -166,6 +166,7 @@ void module::assign(const module& m)
auto s = ins->get_shape(); auto s = ins->get_shape();
copy_ins = impl->insert(impl->instructions.end(), copy_ins = impl->insert(impl->instructions.end(),
{builtin::param{name, order}, std::move(s), {}}); {builtin::param{name, order}, std::move(s), {}});
impl->nparams++;
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
{ {
...@@ -594,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const ...@@ -594,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const
} }
} }
std::vector<instruction_ref> module::get_returns() const
{
auto last = std::prev(this->end());
if(last->name() == "@return")
return last->inputs();
return {last};
}
instruction_ref module::validate() const instruction_ref module::validate() const
{ {
return std::find_if( return std::find_if(
......
...@@ -172,6 +172,22 @@ struct vector_stream ...@@ -172,6 +172,22 @@ struct vector_stream
} }
}; };
struct writer_stream
{
std::function<void(const char*, std::size_t)> writer;
writer_stream& write(const char* b, std::size_t n)
{
writer(b, n);
return *this;
}
};
void to_msgpack(const value& v, std::function<void(const char*, std::size_t)> writer)
{
writer_stream ws{std::move(writer)};
msgpack::pack(ws, v);
}
std::vector<char> to_msgpack(const value& v) std::vector<char> to_msgpack(const value& v)
{ {
vector_stream vs; vector_stream vs;
......
...@@ -30,17 +30,22 @@ ...@@ -30,17 +30,22 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// different attributes /**
// 1) use_input(default)/use_output * Parameters:
// 2) use_rank(default)/use_len * vec: the vector attribute to normalize
// 3) clip_min(default)/not_clip_min * axes: the operator's axes attribute if it exists, empty otherwise
// 3.1) include_min(default)/exclude_min * val: the normalize_axes key and options. Ex: normalize["axes"] =
// 4) clip_max(default)/not_clip_max * value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling
// 4.1) exclude_max(default)/include_max * normalize_attributes(op&, lens)
*
* See normalize_attribute.hpp for explaining the options.
*/
template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec, auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const value& val, const value& val,
const std::vector<std::size_t>& lens) const std::vector<std::size_t>& lens,
Message m)
{ {
std::vector<int64_t> result(vec); std::vector<int64_t> result(vec);
int64_t n_rank = lens.size(); int64_t n_rank = lens.size();
...@@ -81,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -81,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{ {
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{})) if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); MIGRAPHX_THROW(m() + "value out of range!");
} }
} }
else else
{ {
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{})) if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); MIGRAPHX_THROW(m() + "value out of range!");
} }
} }
} }
...@@ -121,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -121,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
if(not std::equal( if(not std::equal(
min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{})) min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); MIGRAPHX_THROW(m() + "attribute out of range!");
} }
} }
else else
{ {
if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{})) if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); MIGRAPHX_THROW(m() + "attribute out of range!");
} }
} }
} }
...@@ -151,6 +156,11 @@ auto tune_pad_attribute(const value& val) ...@@ -151,6 +156,11 @@ auto tune_pad_attribute(const value& val)
return result; return result;
} }
/**
* Assumptions:
* Dimensions to pad start from the third dimension (index 2).
* Called by compute_shape_op() with the `lens` of the first input.
*/
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
{ {
bool tuned = false; bool tuned = false;
...@@ -158,9 +168,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -158,9 +168,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
auto val = op.to_value(); auto val = op.to_value();
if(attrs.contains("normalize_padding")) if(attrs.contains("normalize_padding"))
{ {
auto padding = val.at(attrs.at("normalize_padding").to<std::string>()); auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size(); auto padding_size = padding.size();
// for now, assume the dimensions to pad start at dim 2
auto padding_start = 2; auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start)) if(padding_size == 2 * (lens.size() - padding_start))
...@@ -186,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -186,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const auto& key = rv.get_key(); const auto& key = rv.get_key();
if(val.contains(key)) if(val.contains(key))
{ {
auto vv = val.at(key).without_key(); auto message = [&] { return op.name() + ": " + key + ": "; };
auto vv = val.at(key).without_key();
if(vv.is_array()) if(vv.is_array())
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -195,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -195,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>(); axes = val.at("axes").without_key().to_vector<int64_t>();
} }
auto vec = vv.to_vector<int64_t>(); auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens); auto result = tune_attribute(vec, axes, rv.without_key(), lens, message);
val[key] = result; val[key] = result;
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
...@@ -204,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -204,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else else
{ {
auto num = vv.to<int64_t>(); auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens); auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message);
val[key] = result.front(); val[key] = result.front();
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
......
...@@ -94,7 +94,7 @@ struct onnx_parser ...@@ -94,7 +94,7 @@ struct onnx_parser
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false; bool use_dyn_output = false;
......
...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
auto dim_val = options.default_dim_value; auto dim_val = options.default_dim_value;
if(dim_val != 0) if(dim_val != 0)
{ {
if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0}) if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1})
{ {
MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value" MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value"); "set to non-default value");
} }
else else
{ {
parser.default_dyn_dim_value = {dim_val, dim_val, 0}; parser.default_dyn_dim_value = {dim_val, dim_val};
} }
} }
else else
......
...@@ -41,6 +41,20 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -41,6 +41,20 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REMOVE_LAST_OUTPUT); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REMOVE_LAST_OUTPUT);
static shape shape_from_dyn_dims(shape::type_t shape_type,
const std::vector<shape::dynamic_dimension>& dyn_dims)
{
if(std::all_of(dyn_dims.begin(), dyn_dims.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::vector<std::size_t> dims;
std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), [](auto d) {
return d.max;
});
return {shape_type, dims};
}
return {shape_type, dyn_dims};
}
namespace onnx { namespace onnx {
static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
...@@ -302,7 +316,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini ...@@ -302,7 +316,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
else if(map_dyn_input_dims.count(name) > 0) else if(map_dyn_input_dims.count(name) > 0)
{ {
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type()); shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = {shape_type, map_dyn_input_dims.at(name)}; s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name));
} }
else else
{ {
...@@ -496,7 +510,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -496,7 +510,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return default_dyn_dim_value; return default_dyn_dim_value;
} }
std::size_t tmp = d.dim_value(); std::size_t tmp = d.dim_value();
return {tmp, tmp, 0}; return {tmp, tmp};
} }
else else
{ {
...@@ -508,16 +522,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -508,16 +522,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
{ {
return {shape_type}; return {shape_type};
} }
if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); })) return shape_from_dyn_dims(shape_type, dynamic_dims);
{
std::vector<std::size_t> dims;
std::transform(dynamic_dims.begin(),
dynamic_dims.end(),
std::back_inserter(dims),
[](auto d) { return d.max; });
return {shape_type, dims};
}
return {shape_type, dynamic_dims};
} }
shape::type_t get_type(int dtype) shape::type_t get_type(int dtype)
......
...@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers() ...@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
op_parser_map().end(), op_parser_map().end(),
std::back_inserter(result), std::back_inserter(result),
[&](auto&& p) { return p.first; }); [&](auto&& p) { return p.first; });
std::sort(result.begin(), result.end());
return result; return result;
} }
......
...@@ -32,8 +32,7 @@ namespace onnx { ...@@ -32,8 +32,7 @@ namespace onnx {
struct parse_instancenorm : op_parser<parse_instancenorm> struct parse_instancenorm : op_parser<parse_instancenorm>
{ {
const std::set<shape::type_t> valid_types = { std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; } std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }
......
...@@ -33,8 +33,7 @@ namespace onnx { ...@@ -33,8 +33,7 @@ namespace onnx {
struct parse_mean : op_parser<parse_mean> struct parse_mean : op_parser<parse_mean>
{ {
const std::set<shape::type_t> float_types = { std::set<shape::type_t> float_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"Mean"}}; } std::vector<op_desc> operators() const { return {{"Mean"}}; }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
auto n_dim = input_lens.size(); auto n_dim = input_lens.size();
instruction_ref y_scale; instruction_ref y_scale = args[1];
if(args[1]->get_shape().elements() != 1) if(args[1]->get_shape().elements() != 1)
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction( y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]); make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
} }
else
{ auto common_args = add_common_args(*info.mod, {args[0], y_scale});
y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
if(args.size() == 3) if(args.size() == 3)
{ {
...@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point); make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point); common_args.push_back(y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale); return info.add_instruction(make_op("quantizelinear"), common_args);
} }
}; };
......
...@@ -35,8 +35,7 @@ namespace onnx { ...@@ -35,8 +35,7 @@ namespace onnx {
struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops> struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
{ {
const std::set<shape::type_t> valid_types = { std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; } std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; }
......
...@@ -35,8 +35,7 @@ namespace onnx { ...@@ -35,8 +35,7 @@ namespace onnx {
struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops> struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
{ {
const std::set<shape::type_t> valid_types = { std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; } std::vector<op_desc> operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; }
......
...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
} }
return info.add_instruction(make_op("reshape", {{"dims", dims}}), auto cont = info.add_instruction(make_op("contiguous"), args[0]);
info.make_contiguous(args[0])); return info.add_instruction(make_op("reshape", {{"dims", dims}}), cont);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment