Commit f9437603 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 781ce146 658cdab0
......@@ -23,6 +23,7 @@ struct identity
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -12,6 +12,7 @@ namespace op {
struct less : binary<less>
{
std::string point_function() const { return "<"; }
auto apply() const
{
return [](auto x, auto y) { return x < y; };
......
......@@ -36,6 +36,7 @@ struct load
MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op)
......
......@@ -12,6 +12,7 @@ namespace op {
struct logical_and : binary<logical_and>
{
std::string point_function() const { return "&&"; }
auto apply() const
{
return [](auto x, auto y) { return static_cast<bool>(x) and static_cast<bool>(y); };
......
......@@ -12,6 +12,7 @@ namespace op {
struct logical_or : binary<logical_or>
{
std::string point_function() const { return "||"; }
auto apply() const
{
return [](auto x, auto y) { return static_cast<bool>(x) or static_cast<bool>(y); };
......
......@@ -12,6 +12,7 @@ namespace op {
struct logical_xor : binary<logical_xor>
{
std::string point_function() const { return "^"; }
auto apply() const
{
return [](auto x, auto y) { return static_cast<bool>(x) xor static_cast<bool>(y); };
......
......@@ -24,6 +24,7 @@ struct mul : binary<mul>
a["commutative"] = true;
return a;
}
std::string point_function() const { return "*"; }
auto apply() const
{
return [](auto x, auto y) { return x * y; };
......
......@@ -68,6 +68,7 @@ struct multibroadcast
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -18,6 +18,7 @@ namespace op {
struct neg : unary<neg>
{
std::string point_function() const { return "-"; }
auto apply() const
{
return [](auto x) { return -x; };
......
......@@ -18,6 +18,7 @@ namespace op {
struct relu : unary<relu>
{
std::string point_op() const { return "${function:max}(decltype(${0}){0}, ${0})"; }
auto apply() const
{
return [](auto x) { return std::max(decltype(x){0}, x); };
......
......@@ -71,6 +71,7 @@ struct reshape
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -39,6 +39,7 @@ struct scalar
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -9,6 +9,7 @@ namespace op {
struct sqdiff : binary<sqdiff>
{
std::string point_op() const { return "(${0} - ${1}) * (${0} - ${1})"; }
auto apply() const
{
return [](auto x, auto y) { return (x - y) * (x - y); };
......
......@@ -77,6 +77,7 @@ struct squeeze
{
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -18,6 +18,7 @@ namespace op {
struct sub : binary<sub>
{
std::string point_function() const { return "-"; }
auto apply() const
{
return [](auto x, auto y) { return x - y; };
......
......@@ -63,6 +63,7 @@ struct transpose
{
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -5,6 +5,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
......@@ -14,7 +15,27 @@ namespace op {
template <class Derived>
struct unary : op_name<Derived>
{
value base_attributes() const { return {{"pointwise", true}}; }
std::string point_function() const { return this->name(); }
std::string point_op() const
{
const auto& self = static_cast<const Derived&>(*this);
auto pf = self.point_function();
if(pf.empty())
return {};
if(with_char(::ispunct)(pf.front()))
{
return pf + "${0}";
}
else
{
return "${function:" + pf + "}(${0})";
}
}
value base_attributes() const
{
const auto& self = static_cast<const Derived&>(*this);
return {{"pointwise", true}, {"point_op", self.point_op()}};
}
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -12,6 +12,7 @@ namespace op {
struct unary_not : unary<unary_not>
{
std::string point_function() const { return "!"; }
auto apply() const
{
return [](auto x) { return not x; };
......
......@@ -70,6 +70,7 @@ struct unsqueeze
{
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -341,6 +341,29 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
template <class T>
auto compile_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.compile(auto_any_cast(ctx), output_shape, input))
{
return x.compile(auto_any_cast(ctx), output_shape, input);
}
template <class T>
value compile_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
return value::object{};
}
template <class T>
value compile_op(const T& x,
context& ctx,
const shape& output_shape,
const std::vector<shape>& input)
{
return compile_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
value attributes_op(const T&)
{
......@@ -361,6 +384,12 @@ void from_value_op(T& x, const value& v)
return migraphx::from_value(v, x);
}
template <class T>
bool is_borrowed_op(const T&)
{
return false;
}
} // namespace detail
/*
......@@ -372,7 +401,9 @@ void from_value_op(T& x, const value& v)
* bool is_context_free() const;
* bool need_normalization() const;
* bool has_finalize() const;
* bool is_borrowed() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
......@@ -475,12 +506,24 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize();
}
bool is_borrowed() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed();
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input);
}
value compile(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compile(ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -569,7 +612,10 @@ struct operation
virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
......@@ -630,6 +676,19 @@ struct operation
return detail::has_finalize_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed())
{
return private_detail_te_self.is_borrowed();
}
template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self)
{
return detail::is_borrowed_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_output_alias(char,
T&& private_detail_te_self,
......@@ -647,6 +706,27 @@ struct operation
return detail::output_alias_op(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_compile(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.compile(ctx, output, input))
{
return private_detail_te_self.compile(ctx, output, input);
}
template <class T>
static value private_detail_te_default_compile(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
{
return detail::compile_op(private_detail_te_self, ctx, output, input);
}
template <class T>
static auto private_detail_te_default_finalize(char,
T&& private_detail_te_self,
......@@ -858,12 +938,25 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
}
bool is_borrowed() const override
{
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
{
return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
}
value compile(context& ctx, const shape& output, const std::vector<shape>& input) override
{
return private_detail_te_default_compile(
char(0), private_detail_te_value, ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
{
......@@ -1010,6 +1103,24 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
return op.compile(ctx, output_shape, input);
}
template <class Context>
inline value
compile(operation& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
dependent_type<context, Context> ctx2 = std::ref(ctx);
return compile(op, ctx2, output_shape, input);
}
template <class T, class Context>
inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(op.compile(ctx, ctx, output_shape, input))
{
return op.compile(ctx, ctx, output_shape, input);
}
inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
{
return op.compute_shape(inputs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment