Commit f9437603 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 781ce146 658cdab0
...@@ -23,6 +23,7 @@ struct identity ...@@ -23,6 +23,7 @@ struct identity
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -12,6 +12,7 @@ namespace op { ...@@ -12,6 +12,7 @@ namespace op {
struct less : binary<less> struct less : binary<less>
{ {
std::string point_function() const { return "<"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return x < y; }; return [](auto x, auto y) { return x < y; };
......
...@@ -36,6 +36,7 @@ struct load ...@@ -36,6 +36,7 @@ struct load
MIGRAPHX_THROW("Load access is out of bounds"); MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset}; return {s, args[0].data() + offset};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op) friend std::ostream& operator<<(std::ostream& os, const load& op)
......
...@@ -12,6 +12,7 @@ namespace op { ...@@ -12,6 +12,7 @@ namespace op {
struct logical_and : binary<logical_and> struct logical_and : binary<logical_and>
{ {
std::string point_function() const { return "&&"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return static_cast<bool>(x) and static_cast<bool>(y); }; return [](auto x, auto y) { return static_cast<bool>(x) and static_cast<bool>(y); };
......
...@@ -12,6 +12,7 @@ namespace op { ...@@ -12,6 +12,7 @@ namespace op {
struct logical_or : binary<logical_or> struct logical_or : binary<logical_or>
{ {
std::string point_function() const { return "||"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return static_cast<bool>(x) or static_cast<bool>(y); }; return [](auto x, auto y) { return static_cast<bool>(x) or static_cast<bool>(y); };
......
...@@ -12,6 +12,7 @@ namespace op { ...@@ -12,6 +12,7 @@ namespace op {
struct logical_xor : binary<logical_xor> struct logical_xor : binary<logical_xor>
{ {
std::string point_function() const { return "^"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return static_cast<bool>(x) xor static_cast<bool>(y); }; return [](auto x, auto y) { return static_cast<bool>(x) xor static_cast<bool>(y); };
......
...@@ -24,6 +24,7 @@ struct mul : binary<mul> ...@@ -24,6 +24,7 @@ struct mul : binary<mul>
a["commutative"] = true; a["commutative"] = true;
return a; return a;
} }
std::string point_function() const { return "*"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return x * y; }; return [](auto x, auto y) { return x * y; };
......
...@@ -68,6 +68,7 @@ struct multibroadcast ...@@ -68,6 +68,7 @@ struct multibroadcast
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -18,6 +18,7 @@ namespace op { ...@@ -18,6 +18,7 @@ namespace op {
struct neg : unary<neg> struct neg : unary<neg>
{ {
std::string point_function() const { return "-"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return -x; }; return [](auto x) { return -x; };
......
...@@ -18,6 +18,7 @@ namespace op { ...@@ -18,6 +18,7 @@ namespace op {
struct relu : unary<relu> struct relu : unary<relu>
{ {
std::string point_op() const { return "${function:max}(decltype(${0}){0}, ${0})"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return std::max(decltype(x){0}, x); }; return [](auto x) { return std::max(decltype(x){0}, x); };
......
...@@ -71,6 +71,7 @@ struct reshape ...@@ -71,6 +71,7 @@ struct reshape
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -39,6 +39,7 @@ struct scalar ...@@ -39,6 +39,7 @@ struct scalar
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -9,6 +9,7 @@ namespace op { ...@@ -9,6 +9,7 @@ namespace op {
struct sqdiff : binary<sqdiff> struct sqdiff : binary<sqdiff>
{ {
std::string point_op() const { return "(${0} - ${1}) * (${0} - ${1})"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return (x - y) * (x - y); }; return [](auto x, auto y) { return (x - y) * (x - y); };
......
...@@ -77,6 +77,7 @@ struct squeeze ...@@ -77,6 +77,7 @@ struct squeeze
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -18,6 +18,7 @@ namespace op { ...@@ -18,6 +18,7 @@ namespace op {
struct sub : binary<sub> struct sub : binary<sub>
{ {
std::string point_function() const { return "-"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return x - y; }; return [](auto x, auto y) { return x - y; };
......
...@@ -63,6 +63,7 @@ struct transpose ...@@ -63,6 +63,7 @@ struct transpose
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
namespace migraphx { namespace migraphx {
...@@ -14,7 +15,27 @@ namespace op { ...@@ -14,7 +15,27 @@ namespace op {
template <class Derived> template <class Derived>
struct unary : op_name<Derived> struct unary : op_name<Derived>
{ {
value base_attributes() const { return {{"pointwise", true}}; } std::string point_function() const { return this->name(); }
std::string point_op() const
{
const auto& self = static_cast<const Derived&>(*this);
auto pf = self.point_function();
if(pf.empty())
return {};
if(with_char(::ispunct)(pf.front()))
{
return pf + "${0}";
}
else
{
return "${function:" + pf + "}(${0})";
}
}
value base_attributes() const
{
const auto& self = static_cast<const Derived&>(*this);
return {{"pointwise", true}, {"point_op", self.point_op()}};
}
value attributes() const { return base_attributes(); } value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -12,6 +12,7 @@ namespace op { ...@@ -12,6 +12,7 @@ namespace op {
struct unary_not : unary<unary_not> struct unary_not : unary<unary_not>
{ {
std::string point_function() const { return "!"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return not x; }; return [](auto x) { return not x; };
......
...@@ -70,6 +70,7 @@ struct unsqueeze ...@@ -70,6 +70,7 @@ struct unsqueeze
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -341,6 +341,29 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -341,6 +341,29 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {}; return {};
} }
template <class T>
auto compile_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.compile(auto_any_cast(ctx), output_shape, input))
{
return x.compile(auto_any_cast(ctx), output_shape, input);
}
template <class T>
value compile_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
return value::object{};
}
template <class T>
value compile_op(const T& x,
context& ctx,
const shape& output_shape,
const std::vector<shape>& input)
{
return compile_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T> template <class T>
value attributes_op(const T&) value attributes_op(const T&)
{ {
...@@ -361,6 +384,12 @@ void from_value_op(T& x, const value& v) ...@@ -361,6 +384,12 @@ void from_value_op(T& x, const value& v)
return migraphx::from_value(v, x); return migraphx::from_value(v, x);
} }
template <class T>
bool is_borrowed_op(const T&)
{
return false;
}
} // namespace detail } // namespace detail
/* /*
...@@ -372,7 +401,9 @@ void from_value_op(T& x, const value& v) ...@@ -372,7 +401,9 @@ void from_value_op(T& x, const value& v)
* bool is_context_free() const; * bool is_context_free() const;
* bool need_normalization() const; * bool need_normalization() const;
* bool has_finalize() const; * bool has_finalize() const;
* bool is_borrowed() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; * std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>& * shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
...@@ -475,12 +506,24 @@ struct operation ...@@ -475,12 +506,24 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize(); return (*this).private_detail_te_get_handle().has_finalize();
} }
bool is_borrowed() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed();
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const std::ptrdiff_t output_alias(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input); return (*this).private_detail_te_get_handle().output_alias(input);
} }
value compile(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compile(ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -569,7 +612,10 @@ struct operation ...@@ -569,7 +612,10 @@ struct operation
virtual bool is_context_free() const = 0; virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0; virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0; virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0; virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual void virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0; finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0; virtual shape compute_shape(const std::vector<shape>& input) const = 0;
...@@ -630,6 +676,19 @@ struct operation ...@@ -630,6 +676,19 @@ struct operation
return detail::has_finalize_op(private_detail_te_self); return detail::has_finalize_op(private_detail_te_self);
} }
template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed())
{
return private_detail_te_self.is_borrowed();
}
template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self)
{
return detail::is_borrowed_op(private_detail_te_self);
}
template <class T> template <class T>
static auto private_detail_te_default_output_alias(char, static auto private_detail_te_default_output_alias(char,
T&& private_detail_te_self, T&& private_detail_te_self,
...@@ -647,6 +706,27 @@ struct operation ...@@ -647,6 +706,27 @@ struct operation
return detail::output_alias_op(private_detail_te_self, input); return detail::output_alias_op(private_detail_te_self, input);
} }
template <class T>
static auto private_detail_te_default_compile(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.compile(ctx, output, input))
{
return private_detail_te_self.compile(ctx, output, input);
}
template <class T>
static value private_detail_te_default_compile(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
{
return detail::compile_op(private_detail_te_self, ctx, output, input);
}
template <class T> template <class T>
static auto private_detail_te_default_finalize(char, static auto private_detail_te_default_finalize(char,
T&& private_detail_te_self, T&& private_detail_te_self,
...@@ -858,12 +938,25 @@ struct operation ...@@ -858,12 +938,25 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value); return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
} }
bool is_borrowed() const override
{
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
{ {
return private_detail_te_default_output_alias(char(0), private_detail_te_value, input); return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
} }
value compile(context& ctx, const shape& output, const std::vector<shape>& input) override
{
return private_detail_te_default_compile(
char(0), private_detail_te_value, ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
{ {
...@@ -1010,6 +1103,24 @@ inline const ValueType& any_cast(const operation& x) ...@@ -1010,6 +1103,24 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
return op.compile(ctx, output_shape, input);
}
template <class Context>
inline value
compile(operation& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
dependent_type<context, Context> ctx2 = std::ref(ctx);
return compile(op, ctx2, output_shape, input);
}
template <class T, class Context>
inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(op.compile(ctx, ctx, output_shape, input))
{
return op.compile(ctx, ctx, output_shape, input);
}
inline shape compute_shape(const operation& op, const std::vector<shape>& inputs) inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
{ {
return op.compute_shape(inputs); return op.compute_shape(inputs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment