Unverified Commit 24933bd8 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add pointwise attribute to operators (#634)



* Add pointwise attribute

* Formatting

* Fix compilation

* Remove unused variable

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 4fdc4dfe
...@@ -570,6 +570,12 @@ inline auto has_value(T x, float tolerance = 1e-6) ...@@ -570,6 +570,12 @@ inline auto has_value(T x, float tolerance = 1e-6)
}); });
} }
inline auto has_attribute(const std::string& name)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}
} // namespace match } // namespace match
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -18,6 +18,12 @@ namespace op { ...@@ -18,6 +18,12 @@ namespace op {
struct add : binary<add> struct add : binary<add>
{ {
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return x + y; }; return [](auto x, auto y) { return x + y; };
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -13,6 +14,8 @@ namespace op { ...@@ -13,6 +14,8 @@ namespace op {
template <class Derived> template <class Derived>
struct binary : op_name<Derived> struct binary : op_name<Derived>
{ {
value base_attributes() const { return {{"pointwise", true}}; }
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
......
...@@ -13,6 +13,12 @@ namespace op { ...@@ -13,6 +13,12 @@ namespace op {
struct equal : binary<equal> struct equal : binary<equal>
{ {
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return float_equal(x, y); }; return [](auto x, auto y) { return float_equal(x, y); };
......
...@@ -18,6 +18,12 @@ namespace op { ...@@ -18,6 +18,12 @@ namespace op {
struct max : binary<max> struct max : binary<max>
{ {
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return std::max(x, y); }; return [](auto x, auto y) { return std::max(x, y); };
......
...@@ -18,6 +18,12 @@ namespace op { ...@@ -18,6 +18,12 @@ namespace op {
struct min : binary<min> struct min : binary<min>
{ {
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return std::min(x, y); }; return [](auto x, auto y) { return std::min(x, y); };
......
...@@ -18,6 +18,12 @@ namespace op { ...@@ -18,6 +18,12 @@ namespace op {
struct mul : binary<mul> struct mul : binary<mul>
{ {
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return x * y; }; return [](auto x, auto y) { return x * y; };
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -13,6 +14,8 @@ namespace op { ...@@ -13,6 +14,8 @@ namespace op {
template <class Derived> template <class Derived>
struct unary : op_name<Derived> struct unary : op_name<Derived>
{ {
value base_attributes() const { return {{"pointwise", true}}; }
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
......
...@@ -219,6 +219,12 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -219,6 +219,12 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {}; return {};
} }
template <class T>
value attributes_op(const T&)
{
return value::object{};
}
template <class T> template <class T>
value to_value_op(const T& x) value to_value_op(const T& x)
{ {
...@@ -248,6 +254,7 @@ void from_value_op(T& x, const value& v) ...@@ -248,6 +254,7 @@ void from_value_op(T& x, const value& v)
* argument compute(const shape& output,const std::vector<argument>& input) const; * argument compute(const shape& output,const std::vector<argument>& input) const;
* value to_value() const; * value to_value() const;
* void from_value(const value& v) ; * void from_value(const value& v) ;
* value attributes() const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ; * friend bool operator==(const operation & x,const operation & y) ;
* }; * };
...@@ -377,6 +384,12 @@ struct operation ...@@ -377,6 +384,12 @@ struct operation
(*this).private_detail_te_get_handle().from_value(v); (*this).private_detail_te_get_handle().from_value(v);
} }
value attributes() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().attributes();
}
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
{ {
assert(op.private_detail_te_handle_mem_var); assert(op.private_detail_te_handle_mem_var);
...@@ -414,6 +427,7 @@ struct operation ...@@ -414,6 +427,7 @@ struct operation
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0; virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual value to_value() const = 0; virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0; virtual void from_value(const value& v) = 0;
virtual value attributes() const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0; virtual bool operator==(const operation& y) const = 0;
}; };
...@@ -550,6 +564,19 @@ struct operation ...@@ -550,6 +564,19 @@ struct operation
detail::from_value_op(private_detail_te_self, v); detail::from_value_op(private_detail_te_self, v);
} }
template <class T>
static auto private_detail_te_default_attributes(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.attributes())
{
return private_detail_te_self.attributes();
}
template <class T>
static value private_detail_te_default_attributes(float, T&& private_detail_te_self)
{
return detail::attributes_op(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -639,6 +666,12 @@ struct operation ...@@ -639,6 +666,12 @@ struct operation
private_detail_te_default_from_value(char(0), private_detail_te_value, v); private_detail_te_default_from_value(char(0), private_detail_te_value, v);
} }
value attributes() const override
{
return private_detail_te_default_attributes(char(0), private_detail_te_value);
}
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
{ {
using migraphx::detail::operation_operators::operator<<; using migraphx::detail::operation_operators::operator<<;
......
...@@ -35,6 +35,13 @@ auto conv_const_weights() ...@@ -35,6 +35,13 @@ auto conv_const_weights()
match::args(match::any(), match::is_constant().bind("w"))); match::args(match::any(), match::is_constant().bind("w")));
} }
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
ms...);
}
struct find_mul_conv struct find_mul_conv
{ {
auto matcher() const auto matcher() const
...@@ -237,7 +244,8 @@ struct find_inner_broadcast ...@@ -237,7 +244,8 @@ struct find_inner_broadcast
{ {
auto matcher() const auto matcher() const
{ {
return match::name("mul", "add")( return pointwise(
match::nargs(2),
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y"))); match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
} }
...@@ -264,7 +272,7 @@ struct find_concat_op ...@@ -264,7 +272,7 @@ struct find_concat_op
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::any_of[match::inputs()]( return match::name("concat")(match::any_of[match::inputs()](
match::name("add", "mul", "relu", "broadcast"), match::used_once())); match::any_of(pointwise(), match::name("broadcast")), match::used_once()));
} }
template <class Iterator> template <class Iterator>
...@@ -281,6 +289,11 @@ struct find_concat_op ...@@ -281,6 +289,11 @@ struct find_concat_op
return lens; return lens;
} }
static bool is_valid_op(const operation& op)
{
return op.name() == "broadcast" or op.attributes().contains("pointwise");
}
void apply(program& p, const match::matcher_result& r) const void apply(program& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -292,10 +305,9 @@ struct find_concat_op ...@@ -292,10 +305,9 @@ struct find_concat_op
auto x = *start; auto x = *start;
if(x->inputs().size() > 2 or x->inputs().empty() or x->outputs().size() > 1) if(x->inputs().size() > 2 or x->inputs().empty() or x->outputs().size() > 1)
return {start, last}; return {start, last};
auto&& name = x->name(); auto op = x->get_operator();
if(not contains({"add", "mul", "relu", "broadcast"}, name)) if(not is_valid_op(op))
return {start, last}; return {start, last};
auto op = x->get_operator();
auto iaxis = axis; auto iaxis = axis;
// Adjust broadcast lens // Adjust broadcast lens
if(op.name() == "broadcast") if(op.name() == "broadcast")
...@@ -379,8 +391,8 @@ struct find_splits ...@@ -379,8 +391,8 @@ struct find_splits
{ {
auto matcher() const auto matcher() const
{ {
return match::any(match::any_of[match::outputs()](match::name("slice")( return match::any(match::any_of[match::outputs()](
match::any_of[match::outputs()](match::name("add", "mul", "relu"))))); match::name("slice")(match::any_of[match::outputs()](pointwise()))));
} }
static std::vector<std::vector<instruction_ref>> static std::vector<std::vector<instruction_ref>>
......
...@@ -219,6 +219,12 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -219,6 +219,12 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {}; return {};
} }
template <class T>
value attributes_op(const T&)
{
return value::object{};
}
template <class T> template <class T>
value to_value_op(const T& x) value to_value_op(const T& x)
{ {
...@@ -266,6 +272,7 @@ void from_value_op(T& x, const value& v) ...@@ -266,6 +272,7 @@ void from_value_op(T& x, const value& v)
default = 'detail::compute_op'), default = 'detail::compute_op'),
virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'), virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'),
virtual('from_value', v = 'const value&', default = 'detail::from_value_op'), virtual('from_value', v = 'const value&', default = 'detail::from_value_op'),
virtual('attributes', returns = 'value', const = True, default = 'detail::attributes_op'),
friend('operator<<', friend('operator<<',
returns = 'std::ostream &', returns = 'std::ostream &',
os = 'std::ostream &', os = 'std::ostream &',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment