Unverified Commit 24933bd8 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add pointwise attribute to operators (#634)



* Add pointwise attribute

* Formatting

* Fix compilation

* Remove unused variable

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 4fdc4dfe
......@@ -570,6 +570,12 @@ inline auto has_value(T x, float tolerance = 1e-6)
});
}
inline auto has_attribute(const std::string& name)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -18,6 +18,12 @@ namespace op {
struct add : binary<add>
{
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const
{
return [](auto x, auto y) { return x + y; };
......
......@@ -5,6 +5,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -13,6 +14,8 @@ namespace op {
template <class Derived>
struct binary : op_name<Derived>
{
value base_attributes() const { return {{"pointwise", true}}; }
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
......
......@@ -13,6 +13,12 @@ namespace op {
struct equal : binary<equal>
{
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const
{
return [](auto x, auto y) { return float_equal(x, y); };
......
......@@ -18,6 +18,12 @@ namespace op {
struct max : binary<max>
{
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const
{
return [](auto x, auto y) { return std::max(x, y); };
......
......@@ -18,6 +18,12 @@ namespace op {
struct min : binary<min>
{
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const
{
return [](auto x, auto y) { return std::min(x, y); };
......
......@@ -18,6 +18,12 @@ namespace op {
struct mul : binary<mul>
{
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
auto apply() const
{
return [](auto x, auto y) { return x * y; };
......
......@@ -5,6 +5,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -13,6 +14,8 @@ namespace op {
template <class Derived>
struct unary : op_name<Derived>
{
value base_attributes() const { return {{"pointwise", true}}; }
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
......
......@@ -219,6 +219,12 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
template <class T>
value attributes_op(const T&)
{
return value::object{};
}
template <class T>
value to_value_op(const T& x)
{
......@@ -248,6 +254,7 @@ void from_value_op(T& x, const value& v)
* argument compute(const shape& output,const std::vector<argument>& input) const;
* value to_value() const;
* void from_value(const value& v) ;
* value attributes() const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
......@@ -377,6 +384,12 @@ struct operation
(*this).private_detail_te_get_handle().from_value(v);
}
value attributes() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().attributes();
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
{
assert(op.private_detail_te_handle_mem_var);
......@@ -414,6 +427,7 @@ struct operation
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual value attributes() const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
};
......@@ -550,6 +564,19 @@ struct operation
detail::from_value_op(private_detail_te_self, v);
}
template <class T>
static auto private_detail_te_default_attributes(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.attributes())
{
return private_detail_te_self.attributes();
}
template <class T>
static value private_detail_te_default_attributes(float, T&& private_detail_te_self)
{
return detail::attributes_op(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -639,6 +666,12 @@ struct operation
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
}
value attributes() const override
{
return private_detail_te_default_attributes(char(0), private_detail_te_value);
}
std::ostream& operator_shift_left(std::ostream& os) const override
{
using migraphx::detail::operation_operators::operator<<;
......
......@@ -35,6 +35,13 @@ auto conv_const_weights()
match::args(match::any(), match::is_constant().bind("w")));
}
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
ms...);
}
struct find_mul_conv
{
auto matcher() const
......@@ -237,7 +244,8 @@ struct find_inner_broadcast
{
auto matcher() const
{
return match::name("mul", "add")(
return pointwise(
match::nargs(2),
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
......@@ -264,7 +272,7 @@ struct find_concat_op
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::name("add", "mul", "relu", "broadcast"), match::used_once()));
match::any_of(pointwise(), match::name("broadcast")), match::used_once()));
}
template <class Iterator>
......@@ -281,6 +289,11 @@ struct find_concat_op
return lens;
}
static bool is_valid_op(const operation& op)
{
return op.name() == "broadcast" or op.attributes().contains("pointwise");
}
void apply(program& p, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -292,10 +305,9 @@ struct find_concat_op
auto x = *start;
if(x->inputs().size() > 2 or x->inputs().empty() or x->outputs().size() > 1)
return {start, last};
auto&& name = x->name();
if(not contains({"add", "mul", "relu", "broadcast"}, name))
auto op = x->get_operator();
if(not is_valid_op(op))
return {start, last};
auto op = x->get_operator();
auto iaxis = axis;
// Adjust broadcast lens
if(op.name() == "broadcast")
......@@ -379,8 +391,8 @@ struct find_splits
{
auto matcher() const
{
return match::any(match::any_of[match::outputs()](match::name("slice")(
match::any_of[match::outputs()](match::name("add", "mul", "relu")))));
return match::any(match::any_of[match::outputs()](
match::name("slice")(match::any_of[match::outputs()](pointwise()))));
}
static std::vector<std::vector<instruction_ref>>
......
......@@ -219,6 +219,12 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
template <class T>
value attributes_op(const T&)
{
return value::object{};
}
template <class T>
value to_value_op(const T& x)
{
......@@ -266,6 +272,7 @@ void from_value_op(T& x, const value& v)
default = 'detail::compute_op'),
virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'),
virtual('from_value', v = 'const value&', default = 'detail::from_value_op'),
virtual('attributes', returns = 'value', const = True, default = 'detail::attributes_op'),
friend('operator<<',
returns = 'std::ostream &',
os = 'std::ostream &',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment