Commit 31065c7d authored by charlie

Merge branch 'dyn_squeeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 6bec381f 6acbd4e4
@@ -208,7 +208,7 @@ struct schedule_model
     template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
     private_detail_te_handle_type(
         PrivateDetailTypeErasedT value,
-        typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
+        typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                 int>::type* = nullptr) noexcept
         : private_detail_te_value(std::move(value))
     {
@@ -274,7 +274,7 @@ struct schedule_model
     private_detail_te_handle_base_type& private_detail_te_get_handle()
     {
         assert(private_detail_te_handle_mem_var != nullptr);
-        if(!private_detail_te_handle_mem_var.unique())
+        if(not private_detail_te_handle_mem_var.unique())
             private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
         return *private_detail_te_handle_mem_var;
     }
...
@@ -30,6 +30,7 @@
 #include <numeric>
 #include <memory>
+#include <migraphx/functional.hpp>
 #include <migraphx/errors.hpp>
 #include <migraphx/half.hpp>
 #include <migraphx/config.hpp>
@@ -89,7 +90,10 @@ struct shape
         std::size_t opt = 0;

         template <class Self, class F>
-        static auto reflect(Self& self, F f);
+        static auto reflect(Self& self, F f)
+        {
+            return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
+        }

         bool is_fixed() const;
         bool has_optimal() const;
@@ -115,6 +119,12 @@ struct shape
     shape(type_t t, std::vector<dynamic_dimension> dims);

+    // Construct a dynamic shape from three sets of lengths (of the same rank)
+    shape(type_t t,
+          std::vector<std::size_t> mins,
+          std::vector<std::size_t> maxes,
+          std::vector<std::size_t> opts);

     template <class Range>
     shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
     {
@@ -136,6 +146,12 @@ struct shape
     const std::vector<std::size_t>& lens() const;
     const std::vector<std::size_t>& strides() const;

+    /*!
+     * The number of dimensions in the shape.
+     * Same as the number of indices required to get a data value.
+     */
+    std::size_t ndim() const;

     /*!
      * Return the number of elements in the tensor.
      */
@@ -221,6 +237,9 @@ struct shape
     shape with_type(type_t t) const;

+    // convert the shape to an equivalent dynamic shape
+    shape to_dynamic() const;

     friend bool operator==(const shape& x, const shape& y);
     friend bool operator!=(const shape& x, const shape& y);
     friend std::ostream& operator<<(std::ostream& os, const shape& x);
...
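Usage note: a minimal sketch of the dynamic-shape additions above (values are hypothetical; assumes <migraphx/shape.hpp>):

    #include <migraphx/shape.hpp>

    // Dynamic shape built from three per-dimension length sets (mins, maxes, opts).
    // Here the batch dimension ranges from 1 to 4 with an optimal size of 2; an
    // opt of 0 is taken to mean "no optimal size", matching the parser's
    // default_dyn_dim_value = {1, 1, 0} further down in this commit.
    migraphx::shape s{migraphx::shape::float_type,
                      {1, 3, 224},  // mins
                      {4, 3, 224},  // maxes
                      {2, 0, 0}};   // opts

    auto rank = s.ndim(); // 3: the number of indices needed to address an element

    // Any fixed shape can be converted to an equivalent dynamic shape:
    migraphx::shape fixed{migraphx::shape::float_type, {2, 3}};
    migraphx::shape dyn = fixed.to_dynamic();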
@@ -216,7 +216,7 @@ struct stream_model
     template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
     private_detail_te_handle_type(
         PrivateDetailTypeErasedT value,
-        typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
+        typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                 int>::type* = nullptr) noexcept
         : private_detail_te_value(std::move(value))
     {
@@ -288,7 +288,7 @@ struct stream_model
     private_detail_te_handle_base_type& private_detail_te_get_handle()
     {
         assert(private_detail_te_handle_mem_var != nullptr);
-        if(!private_detail_te_handle_mem_var.unique())
+        if(not private_detail_te_handle_mem_var.unique())
             private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
         return *private_detail_te_handle_mem_var;
     }
...
@@ -26,8 +26,11 @@
 #include <ostream>
 #include <algorithm>
+#include <migraphx/reflect.hpp>
 #include <migraphx/rank.hpp>
+#include <migraphx/requires.hpp>
 #include <migraphx/config.hpp>
+#include <vector>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -41,7 +44,7 @@ struct stream_range_container
     friend std::ostream& operator<<(std::ostream& os, const stream_range_container& sr)
     {
         assert(sr.r != nullptr);
-        if(!sr.r->empty())
+        if(not sr.r->empty())
         {
             os << sr.r->front();
             std::for_each(
@@ -59,10 +62,22 @@ inline stream_range_container<Range> stream_range(const Range& r)

 namespace detail {

-inline void stream_write_value_impl(rank<2>, std::ostream& os, const std::string& x) { os << x; }
+template <class T>
+auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype(os << x, void())
+{
+    os << x;
+}
+
+template <class T>
+void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r)
+{
+    os << "{";
+    os << stream_range(r);
+    os << "}";
+}

 template <class Range>
-auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
+auto stream_write_value_impl(rank<0>, std::ostream& os, const Range& r)
     -> decltype(r.begin(), r.end(), void())
 {
     os << "{";
@@ -70,17 +85,26 @@ auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
     os << "}";
 }

-template <class T>
+template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
 void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
 {
-    os << x;
+    char delim = '{';
+    reflect_each(x, [&](auto&& y, auto name) {
+        os << delim;
+        os << name << "=";
+        stream_write_value_impl(rank<2>{}, os, y);
+        delim = ',';
+    });
+    if(delim == ',')
+        os << "}";
 }

 } // namespace detail

 template <class T>
 void stream_write_value(std::ostream& os, const T& x)
 {
-    detail::stream_write_value_impl(rank<2>{}, os, x);
+    detail::stream_write_value_impl(rank<1>{}, os, x);
 }

 } // namespace MIGRAPHX_INLINE_NS
...
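Usage note: with the reworked rank dispatch above, any reflectable type is now printed as name=value pairs via reflect_each. A small sketch (the struct is hypothetical; header paths assumed):

    #include <sstream>
    #include <migraphx/functional.hpp> // for migraphx::pack

    struct dims
    {
        std::size_t min = 1;
        std::size_t max = 4;
        template <class Self, class F>
        static auto reflect(Self& self, F f)
        {
            return migraphx::pack(f(self.min, "min"), f(self.max, "max"));
        }
    };

    std::ostringstream ss;
    migraphx::stream_write_value(ss, dims{}); // ss.str() == "{min=1,max=4}"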
@@ -21,22 +21,24 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_GREATER_HPP
-#define MIGRAPHX_GUARD_RTGLIB_GREATER_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
+#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP

-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/greater.hpp>
+#include <unordered_set>
+#include <migraphx/instruction_ref.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {

-struct hip_greater : binary_device<hip_greater, device::greater>
+struct supported_segment
 {
+    std::unordered_set<instruction_ref> instructions;
+    float metric;
 };

-} // namespace gpu
+using supported_segments = std::vector<supported_segment>;

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx

-#endif
+#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
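Usage note: a sketch of how a backend could report one supported segment (the instruction list is hypothetical, and "higher metric is better" is an assumption; the header does not define the scale):

    migraphx::supported_segments
    report_supported(const std::vector<migraphx::instruction_ref>& inss)
    {
        migraphx::supported_segment seg;
        seg.instructions.insert(inss.begin(), inss.end()); // instructions in this segment
        seg.metric = 1.0f; // assumed: quality of support for the chosen support_metric
        return {seg};
    }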
@@ -37,8 +37,10 @@
 #include <migraphx/compile_options.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/rank.hpp>
+#include <migraphx/module_ref.hpp>
 #include <migraphx/support_metric.hpp>
 #include <migraphx/instruction_ref.hpp>
+#include <migraphx/supported_segments.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -64,12 +66,12 @@ struct target
      */
     context get_context() const;
     /**
-     * @brief Check how well an instruction is supported on a target with the given metric
-     * @param ins Instruction to check if it's supported
-     * @param metric Used to define how the return value should be interpreted
-     * @return The value based on the chosen metric. Negative numbers mean unsupported
+     * @brief Get the ranges of instructions that are supported on a target
+     * @param module Module to check for supported instructions
+     * @param metric Used to define how the quality of the support should be measured
+     * @return the supported segments of the graph
      */
-    float is_supported(T&, instruction_ref ins, support_metric m) const;
+    supported_segments target_is_supported(T&, const_module_ref mod, support_metric metric) const;
     /**
      * @brief copy an argument to the current target.
      *
@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
 }

 template <class T>
-float target_is_supported(T&, instruction_ref, support_metric)
+supported_segments target_find_supported(T&, const_module_ref, support_metric)
 {
-    return 0;
+    return {};
 }

 #ifdef TYPE_ERASED_DECLARATION
@@ -132,7 +134,7 @@ struct target
     //
     context get_context() const;
     // (optional)
-    float is_supported(instruction_ref ins, support_metric m) const;
+    supported_segments find_supported(const_module_ref mod, support_metric m) const;
     // (optional)
     argument copy_to(const argument& input) const;
     // (optional)
@@ -224,10 +226,10 @@ struct target
         return (*this).private_detail_te_get_handle().get_context();
     }

-    float is_supported(instruction_ref ins, support_metric m) const
+    supported_segments find_supported(const_module_ref mod, support_metric m) const
    {
        assert((*this).private_detail_te_handle_mem_var);
-       return (*this).private_detail_te_get_handle().is_supported(ins, m);
+       return (*this).private_detail_te_get_handle().find_supported(mod, m);
    }

    argument copy_to(const argument& input) const
@@ -261,33 +263,33 @@ struct target
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
        virtual const std::type_info& type() const                                = 0;

        virtual std::string name() const = 0;
        virtual std::vector<pass> get_passes(context& ctx,
                                             const compile_options& options) const = 0;
        virtual context get_context() const                                        = 0;
-       virtual float is_supported(instruction_ref ins, support_metric m) const = 0;
+       virtual supported_segments find_supported(const_module_ref mod, support_metric m) const = 0;
        virtual argument copy_to(const argument& input) const   = 0;
        virtual argument copy_from(const argument& input) const = 0;
        virtual argument allocate(const shape& s) const         = 0;
    };

    template <class T>
-   static auto private_detail_te_default_is_supported(char,
-                                                      T&& private_detail_te_self,
-                                                      instruction_ref ins,
-                                                      support_metric m)
-       -> decltype(private_detail_te_self.is_supported(ins, m))
+   static auto private_detail_te_default_find_supported(char,
+                                                        T&& private_detail_te_self,
+                                                        const_module_ref mod,
+                                                        support_metric m)
+       -> decltype(private_detail_te_self.find_supported(mod, m))
    {
-       return private_detail_te_self.is_supported(ins, m);
+       return private_detail_te_self.find_supported(mod, m);
    }

    template <class T>
-   static float private_detail_te_default_is_supported(float,
-                                                       T&& private_detail_te_self,
-                                                       instruction_ref ins,
-                                                       support_metric m)
+   static supported_segments private_detail_te_default_find_supported(float,
+                                                                      T&& private_detail_te_self,
+                                                                      const_module_ref mod,
+                                                                      support_metric m)
    {
-       return target_is_supported(private_detail_te_self, ins, m);
+       return target_find_supported(private_detail_te_self, mod, m);
    }

    template <class T>
@@ -349,7 +351,7 @@ struct target
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
-           typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
+           typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
        {
@@ -372,10 +374,11 @@ struct target
        context get_context() const override { return private_detail_te_value.get_context(); }

-       float is_supported(instruction_ref ins, support_metric m) const override
+       supported_segments find_supported(const_module_ref mod, support_metric m) const override
        {
-           return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m);
+           return private_detail_te_default_find_supported(
+               char(0), private_detail_te_value, mod, m);
        }

        argument copy_to(const argument& input) const override
@@ -423,7 +426,7 @@ struct target
     private_detail_te_handle_base_type& private_detail_te_get_handle()
     {
         assert(private_detail_te_handle_mem_var != nullptr);
-        if(!private_detail_te_handle_mem_var.unique())
+        if(not private_detail_te_handle_mem_var.unique())
             private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
         return *private_detail_te_handle_mem_var;
     }
...
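Usage note: a target may implement find_supported itself; otherwise the char/float overload pair above selects the target_find_supported fallback, which claims nothing. A minimal sketch of a conforming member (the target is hypothetical):

    struct my_target
    {
        std::string name() const { return "my_target"; }
        // Optional: report which segments of the module this target supports.
        migraphx::supported_segments find_supported(migraphx::const_module_ref mod,
                                                    migraphx::support_metric m) const
        {
            (void)mod;
            (void)m;
            return {}; // same result as omitting the member entirely
        }
        // get_passes, get_context, copy_to, copy_from, allocate as before...
    };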
@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP

 #include <unordered_map>
+#include <string>

 #include <migraphx/instruction_ref.hpp>
@@ -33,10 +34,20 @@ inline namespace MIGRAPHX_INLINE_NS {

 struct target_assignments
 {
-    void add_assignment(instruction_ref ins, const std::string& target);
+    using iterator   = std::unordered_map<instruction_ref, std::string>::const_iterator;
+    using value_type = std::pair<instruction_ref, std::string>;

-    auto begin() const { return assignments.cbegin(); }
-    auto end() const { return assignments.cend(); }
+    auto size() const { return assignments.size(); }
+    auto& at(instruction_ref ins) const { return assignments.at(ins); }
+
+    auto insert(iterator it, const std::pair<instruction_ref, std::string>& assignment)
+    {
+        return assignments.insert(it, assignment);
+    }
+
+    auto find(instruction_ref ins) const { return assignments.find(ins); }
+
+    auto begin() const { return assignments.begin(); }
+    auto end() const { return assignments.end(); }

     private:
     std::unordered_map<instruction_ref, std::string> assignments;
...
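Usage note: target_assignments now behaves like a read-mostly map instead of exposing add_assignment. A sketch (ins is a hypothetical instruction_ref):

    migraphx::target_assignments assigns;
    // insert mirrors std::unordered_map::insert with a position hint
    assigns.insert(assigns.begin(), {ins, "gpu"});
    if(assigns.find(ins) != assigns.end())
    {
        const auto& tgt = assigns.at(ins); // "gpu"
    }
    auto n = assigns.size(); // 1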
@@ -67,7 +67,7 @@ struct tensor_view

     const shape& get_shape() const { return this->m_shape; }

-    bool empty() const { return m_data == nullptr || m_shape.lens().empty(); }
+    bool empty() const { return m_data == nullptr or m_shape.lens().empty(); }

     std::size_t size() const { return m_shape.elements(); }
@@ -109,37 +109,37 @@ struct tensor_view
     T& operator[](std::size_t i)
     {
-        assert(!this->empty() && i < this->size());
+        assert(not this->empty() && i < this->size());
         return m_data[m_shape.index(i)];
     }

     const T& operator[](std::size_t i) const
     {
-        assert(!this->empty() && i < this->size());
+        assert(not this->empty() && i < this->size());
         return m_data[m_shape.index(i)];
     }

     T& front()
     {
-        assert(!this->empty());
+        assert(not this->empty());
         return m_data[0];
     }

     const T& front() const
     {
-        assert(!this->empty());
+        assert(not this->empty());
         return m_data[0];
     }

     T& back()
     {
-        assert(!this->empty());
+        assert(not this->empty());
         return m_data[m_shape.index(this->size() - 1)];
     }

     const T& back() const
     {
-        assert(!this->empty());
+        assert(not this->empty());
         return m_data[m_shape.index(this->size() - 1)];
     }
@@ -159,7 +159,7 @@ struct tensor_view
     friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
     {
-        if(!x.empty())
+        if(not x.empty())
         {
             os << as_number(x.front());
             for(std::size_t i = 1; i < x.m_shape.elements(); i++)
@@ -182,7 +182,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
 {
     for(std::size_t i = 0; i < x.get_shape().elements(); i++)
     {
-        if(!float_equal(x[i], y[i]))
+        if(not float_equal(x[i], y[i]))
             return false;
     }
     return true;
@@ -193,7 +193,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
 template <class T, class U>
 bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
 {
-    return !(x == y);
+    return not(x == y);
 }

 template <class T>
...
@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {

 inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR")
 {
-    if(axis >= n_dim || std::abs(axis) > n_dim)
+    if(axis >= n_dim or std::abs(axis) > n_dim)
     {
         MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range.");
     }
...
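Usage note: tune_axis rejects an out-of-range axis; by the usual MIGraphX convention (assumed here, since the normalization itself lies outside this hunk) a negative axis maps to axis + n_dim:

    int a0 = migraphx::tune_axis(4, 2, "softmax");  // 2
    int a1 = migraphx::tune_axis(4, -1, "softmax"); // assumed to yield 3
    // migraphx::tune_axis(4, 4, "softmax") throws "SOFTMAX: axis is out of range."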
@@ -184,6 +184,12 @@ struct value
         {
         }
         explicit binary(std::size_t s) : base(s) {}
+
+        friend std::ostream& operator<<(std::ostream& os, const binary& obj)
+        {
+            os << "{binary_object: " << obj.size() << "}";
+            return os;
+        }
     };

     value() = default;
...
@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
     return true;
 }

-bool operator!=(const instruction& x, const instruction& y) { return !(x == y); }
+bool operator!=(const instruction& x, const instruction& y) { return not(x == y); }

 bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }

-bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
+bool operator!=(const instruction& i, instruction_ref ref) { return not(i == ref); }

-bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
+bool operator!=(instruction_ref ref, const instruction& i) { return not(i == ref); }

 void instruction::add_output(instruction_ref ins)
 {
@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
         os << "{" << ins->get_literal() << "}";
     }

-    if(!ins->inputs().empty())
+    if(not ins->inputs().empty())
     {
         char delim = '(';
         for(auto&& arg : ins->inputs())
@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
     }

     // print module inputs
-    if(!ins->module_inputs().empty())
+    if(not ins->module_inputs().empty())
     {
         std::string delim = ", [";
         for(auto&& mod_arg : ins->module_inputs())
@@ -446,7 +446,7 @@ operation instruction::normalized_operator() const
     if(this->need_normalization())
     {
         auto s = this->inputs().front()->get_shape();
-        if(!normalize_attributes(o, s.max_lens()))
+        if(not normalize_attributes(o, s.max_lens()))
             return this->get_operator();
     }
     return o;
...
@@ -25,7 +25,6 @@
 #include <migraphx/file_buffer.hpp>
 #include <migraphx/json.hpp>
 #include <migraphx/msgpack.hpp>
-#include <migraphx/file_buffer.hpp>
 #include <fstream>

 namespace migraphx {
...
@@ -64,5 +64,10 @@ operation make_op_from_value(const std::string& name, const value& v)
     });
 }

+operation make_json_op(const std::string& name, const std::string& s)
+{
+    return make_op(name, from_json_string(convert_to_json(s)));
+}

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
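Usage note: make_json_op routes the string through convert_to_json, so relaxed JSON (unquoted keys) should be accepted as well as strict JSON. The two calls below are intended to build the same op:

    auto t1 = migraphx::make_op("transpose",
                                migraphx::from_json_string(R"({"permutation":[1,0]})"));
    auto t2 = migraphx::make_json_op("transpose", "{permutation:[1,0]}");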
@@ -34,7 +34,6 @@
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_target.hpp>
-#include <migraphx/make_op.hpp>
 #include <migraphx/json.hpp>
 #include <iostream>
 #include <sstream>
@@ -141,12 +140,12 @@ void module::set_bypass(bool b) { impl->bypass = b; }
 void module::assign(const module& m)
 {
     // copy the impl
-    if(!impl)
+    if(not impl)
         impl = std::make_unique<module_impl>();
     *impl = *m.impl;

     // clear instructions
-    if(!impl->instructions.empty())
+    if(not impl->instructions.empty())
     {
         impl->clear();
     }
@@ -346,7 +345,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
         assert(out->valid(begin()));
     }
     // Replacement should not be dead code unless its the last instruction
-    assert(!rep->outputs().empty() or rep == std::prev(end()));
+    assert(not rep->outputs().empty() or rep == std::prev(end()));
     // Output of the original instruction should only be the replacement or empty
     assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(),
                                                  ins->outputs().end(),
@@ -385,9 +384,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref ds
 instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
 {
-    this->move_instruction(src, dst);
     for(auto ins : src->inputs())
-        this->move_instruction(ins, src);
+    {
+        if(not contains(this->impl->instructions, ins))
+            continue;
+        this->move_instructions(ins, dst);
+    }
+    this->move_instruction(src, dst);
     return src;
 }
@@ -598,7 +601,7 @@ instruction_ref module::validate() const
         auto inputs      = i.inputs();
         bool check_order = std::all_of(
             inputs.begin(), inputs.end(), [&](auto in) { return has_instruction(in); });
-        return !i.valid(impl->instructions.begin(), check_order);
+        return not i.valid(impl->instructions.begin(), check_order);
     });
 }
@@ -754,7 +757,7 @@ void module::print_graph(std::ostream& os, bool brief) const
         label = to_string(ins->get_operator());
         os << "\t" << enclose_name(ins_names.at(ins)) << "[label=" << enclose_name(label) << "]";
         os << ";" << std::endl;
-        if(!ins->inputs().empty())
+        if(not ins->inputs().empty())
         {
             for(auto&& arg : ins->inputs())
             {
@@ -788,12 +791,15 @@ static std::string cpp_var_name(const std::string& name)
 static void print_make_op(std::ostream& os, const operation& op)
 {
-    os << "migraphx::make_op(" << enclose_name(op.name());
     auto v = op.to_value();
     if(not v.empty())
     {
-        os << ", "
-           << "migraphx::from_json_string(" << enclose_name(to_json_string(v)) << ")";
+        os << "migraphx::make_json_op(" << enclose_name(op.name());
+        os << ", " << enclose_name(to_json_string(v));
+    }
+    else
+    {
+        os << "migraphx::make_op(" << enclose_name(op.name());
     }
     os << ")";
 }
@@ -905,7 +911,7 @@ module& module::sort()
         this->move_instruction(ins, this->begin());
         for(auto child : ins->inputs())
         {
-            if(!contains(this->impl->instructions, child))
+            if(not contains(this->impl->instructions, child))
             {
                 continue;
             }
...
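Usage note: move_instructions previously moved src first and then only its direct inputs; it now recurses dependency-first and skips inputs owned by other modules, so the moved chain stays topologically valid. A sketch of the effect (the module contents are hypothetical):

    migraphx::module m;
    auto s   = migraphx::shape{migraphx::shape::float_type, {2}};
    auto x   = m.add_parameter("x", s);
    auto y   = m.add_parameter("y", s);
    auto mul = m.add_instruction(migraphx::make_op("mul"), x, y);
    auto sum = m.add_instruction(migraphx::make_op("add"), mul, x);
    // Moves `mul` (an in-module input of `sum`) ahead first, then `sum` itself,
    // so `mul` still precedes `sum` at the destination:
    m.move_instructions(sum, m.end());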
@@ -79,14 +79,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     {
         if(contains(vec_attrs, op::normalize_attribute::include_max))
         {
-            if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
+            if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
             {
                 MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
             }
         }
         else
         {
-            if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
+            if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
             {
                 MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
             }
@@ -118,14 +118,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     {
         if(contains(vec_attrs, op::normalize_attribute::include_min))
         {
-            if(!std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
+            if(not std::equal(
+                   min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
             {
                 MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
             }
         }
         else
         {
-            if(!std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
+            if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
             {
                 MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
             }
@@ -174,7 +175,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
             tuned = true;
         }
     }
-    if(!attrs.contains("normalize_axes"))
+    if(not attrs.contains("normalize_axes"))
     {
         return tuned;
     }
...
@@ -30,7 +30,7 @@ namespace onnx {

 void recalc_conv_attributes(value& v, size_t kdims)
 {
-    if(not(v["padding"].size() == kdims or v["padding"].size() == kdims * 2))
+    if(v["padding"].size() != kdims and v["padding"].size() != kdims * 2)
     {
         v["padding"].resize(kdims);
         std::fill_n(v["padding"].begin(), kdims, 0);
...
@@ -97,6 +97,7 @@ struct onnx_parser
     shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
     std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
     std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
+    bool use_dyn_output         = false;
     bool skip_unknown_operators = false;
     int64_t max_loop_iterations = 10;
     int64_t opset_version       = 13;
...
@@ -60,8 +60,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
     {
         parser.default_dyn_dim_value = options.default_dyn_dim_value;
     }
+    if(not options.map_input_dims.empty() and not options.map_dyn_input_dims.empty())
+    {
+        MIGRAPHX_THROW("PARSE_ONNX_FROM: both map_input_dims and map_dyn_input_dims non-empty, "
+                       "only one should be used");
+    }
     parser.skip_unknown_operators = options.skip_unknown_operators;
     parser.max_loop_iterations    = options.max_loop_iterations;
+    parser.use_dyn_output         = options.use_dyn_output;

     if(options.print_program_on_error)
     {
@@ -80,6 +86,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
     {
         parser.parse_from(std::forward<Ts>(xs)...);
     }
+
     return std::move(parser.prog);
 }
...
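Usage note: the mutual-exclusion check now fires once per parse in parse_onnx_from instead of per graph in parse_graph. A sketch with the public entry point (model path and dimensions are hypothetical):

    migraphx::onnx_options options;
    options.use_dyn_output     = true;
    // one dynamic input: {min, max, opt} per dimension, as in shape::dynamic_dimension
    options.map_dyn_input_dims = {
        {"input", {{1, 4, 2}, {3, 3, 0}, {224, 224, 0}, {224, 224, 0}}}};
    // Also filling options.map_input_dims would now throw from PARSE_ONNX_FROM.
    auto prog = migraphx::parse_onnx("model.onnx", options);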
@@ -187,7 +187,7 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons

 void onnx_parser::parse_undefined(module* mod, const std::string& name)
 {
-    if(!contains(instructions, name))
+    if(not contains(instructions, name))
     {
         auto ins           = mod->add_instruction(make_op("undefined"));
         instructions[name] = ins;
@@ -256,11 +256,6 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)

 void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
 {
-    if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
-    {
-        MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only"
-                       "one should be used");
-    }
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {
@@ -272,7 +267,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
     {
         const std::string& name = input.name();
         // input not in initializer_data, so it is a real input
-        if(!contains(mod_insts, name))
+        if(not contains(mod_insts, name))
         {
             // ONNX specification does not specify how to deal with the
             // scenario that a nested subgraph contains a parameter with the
@@ -359,7 +354,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
                  all_output_names.begin(),
                  all_output_names.end(),
                  std::back_inserter(prog_output_names),
-                 [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });
+                 [&](const auto& name) { return not(name.empty() or instructions.count(name) == 0); });

     std::vector<instruction_ref> output_ins;
     std::transform(prog_output_names.begin(),
@@ -449,7 +444,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
                               const std::vector<std::size_t>& input_dims) const
 {
     shape::type_t shape_type = get_type(t.tensor_type().elem_type());
-    if(!input_dims.empty())
+    if(not input_dims.empty())
     {
         return {shape_type, input_dims};
     }
@@ -516,7 +511,7 @@ shape::type_t get_type(int dtype)
 bool is_type_float(shape::type_t dtype)
 {
     bool r = false;
-    if(dtype == shape::float_type || dtype == shape::double_type || dtype == shape::half_type)
+    if(dtype == shape::float_type or dtype == shape::double_type or dtype == shape::half_type)
     {
         r = true;
     }
...
@@ -42,7 +42,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
     size_t kdims = in_lens.size() - 2;
     assert(k_lens.size() == kdims and dilation.size() == kdims);

-    if(!contains(info.attributes, "auto_pad"))
+    if(not contains(info.attributes, "auto_pad"))
     {
         return;
     }
@@ -124,7 +124,7 @@ void tune_padding_size(const value& v,
     }

     // if padding is symmetric, return directly
-    if(!is_asym_padding(padding))
+    if(not is_asym_padding(padding))
     {
         return;
     }
...