Unverified Commit f7befe50 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Cpu fusions using post_ops (#781)



* Add eliminate_data_type pass

* Formatting

* Auto convert quant ops

* Formatting

* Flip the order of decompose

* Compute max size differently

* Formatting

* Clamp values in convert

* Formatting

* Fix loss of precision in reduce

* Formatting

* Fix bugs in reduction

* Fix accumulator type in reference softmax implementation

* Formatting

* Update convert test

* Remove unused variables

* Remove unnecessary quant_dot check

* Formatting

* Add tests

* Formatting

* Remove unused code

* Remove duplicate ops

* Remove blaze dependency

* Use set since shape::type_t is no hashable on gcc 5

* Formatting

* Add dnnl binary op

* Formatting

* Add binary and eltwise

* Formatting

* Add softmax

* Formatting

* Remove unused operators

* Add missing files

* Formatting

* Add lrn

* Formatting

* Add deconvolution

* Formatting

* Change allocate default

* Add reorder

* Formatting

* Add reductions

* Formatting

* Sort lines

* Change literals in another loop

* Add pow operator

* Formatting

* Add pow operator

* Formatting

* Make sure shapes are packed

* Allow broadcasted inputs

* Remove unused operators

* Simplify functions

* Remove softmax

* Add sub and erf functions

* Formatting

* Fix bug

* Formatting

* Improve parallism

* Formatting

* Allow multiple batch dimensions

* Formatting

* Move literal transforms out of lowering

* Formatting

* Add gather operator

* Sort lines

* Add early exit for carry

* Formatting

* Add missing concat

* Rename macro

* Fix deep nesting

* Formatting

* Fix cppcheck issues

* Remov else

* Move attribute to typedef

* Formatting

* Disable maybe-uninitialized warning since its broken on gcc

* Add constexpr default constructor

* Formatting

* Fix compiler warnings

* Fix adjust_allocation test

* Add layernorm matcher

* Add gelu_erf matcher

* Formatting

* Add gelu_tanh matcher

* Formatting

* Remove match namespace

* Formatting

* Use matcher instead of string

* Formatting

* Add fusions

* Formatting

* Add post op field

* Formatting

* Make post_ops serializable

* Formatting

* Add eltwise fusions

* Formatting

* Fix null conversions

* Formatting

* Add fuse_ops source files

* Formatting

* Set binary post op index correctly

* Formatting

* Fix serialization bugs

* Check if used once

* Formatting

* Fix error in get_primitive_attr

* Formatting

* Add compile function

* Formatting

* Limit fusions

* Formatting

* Disable with env variable instead of using compile arg

* Formatting

* Fix implicit conversion to bool

* Declar on seperate lines

* Formatting

* Fix cppcheck issues

* Fix ICE in pack_join

* Formatting

* Use const ref

* Make enum hashable

* Formatting

* Add explicit this

* Fix merge issues

* Fix dangling ref

* Formatting

* Add test for compile

* Formatting

* Add more value tests

* Formatting
Co-authored-by: default avatarShucai Xiao <shucai@gmail.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 78eaf2b8
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
#include <migraphx/config.hpp>
#include <cstdlib>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F>
auto abort_on_throw(F f) -> decltype(f())
{
try
{
return f();
}
catch(const std::exception& e)
{
std::cerr << e.what() << std::endl;
std::abort();
}
catch(...)
{
std::cerr << "Unknown exception" << std::endl;
std::abort();
}
}
#ifdef NDEBUG
#define MIGRAPHX_ASSERT_NO_THROW(...) __VA_ARGS__
#else
#define MIGRAPHX_ASSERT_NO_THROW(...) \
migraphx::abort_on_throw([&]() -> decltype(__VA_ARGS__) { return __VA_ARGS__; })
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
......@@ -125,24 +125,6 @@ auto fix(F f)
return fix<void>(f);
}
template <class... Ts>
auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
inline auto pack_join() { return pack(); }
template <class P, class... Ps>
auto pack_join(P p, Ps... ps)
{
return [=](auto f) {
return p([&](auto... xs) {
return pack_join(ps...)([&](auto... ys) { return f(xs..., ys...); });
});
};
}
template <class F, class T>
auto fold_impl(F&&, T&& x)
{
......@@ -161,6 +143,22 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
}
template <class... Ts>
auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
inline auto pack_join() { return pack(); }
template <class... Ps>
auto pack_join(Ps... ps)
{
return fold([](auto p1, auto p2) {
return p1([=](auto... xs) { return p2([=](auto... ys) { return pack(xs..., ys...); }); });
})(ps...);
}
template <class F, class Proj>
auto by(F f, Proj proj)
{
......
......@@ -19,6 +19,7 @@ shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs);
struct instruction
{
......
......@@ -183,6 +183,17 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
return {{p}};
}
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }})
{
}
};
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
......
......@@ -124,6 +124,7 @@ struct module
std::vector<shape> get_output_shapes() const;
instruction_ref validate() const;
instruction_ref find_dangling_reference() const;
void finalize(context& ctx);
......
......@@ -35,6 +35,7 @@ struct as_shape
{
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -66,6 +66,7 @@ struct broadcast
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -50,6 +50,7 @@ struct flatten
{
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -23,6 +23,7 @@ struct identity
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -36,6 +36,7 @@ struct load
MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op)
......
......@@ -68,6 +68,7 @@ struct multibroadcast
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -71,6 +71,7 @@ struct reshape
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -39,6 +39,7 @@ struct scalar
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -77,6 +77,7 @@ struct squeeze
{
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -63,6 +63,7 @@ struct transpose
{
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -70,6 +70,7 @@ struct unsqueeze
{
return {std::move(output_shape), std::move(args.front().data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -341,6 +341,29 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
template <class T>
auto compile_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.compile(auto_any_cast(ctx), output_shape, input))
{
return x.compile(auto_any_cast(ctx), output_shape, input);
}
template <class T>
value compile_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
return value::object{};
}
template <class T>
value compile_op(const T& x,
context& ctx,
const shape& output_shape,
const std::vector<shape>& input)
{
return compile_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
value attributes_op(const T&)
{
......@@ -361,6 +384,12 @@ void from_value_op(T& x, const value& v)
return migraphx::from_value(v, x);
}
template <class T>
bool is_borrowed_op(const T&)
{
return false;
}
} // namespace detail
/*
......@@ -372,7 +401,9 @@ void from_value_op(T& x, const value& v)
* bool is_context_free() const;
* bool need_normalization() const;
* bool has_finalize() const;
* bool is_borrowed() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
......@@ -475,12 +506,24 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize();
}
bool is_borrowed() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed();
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input);
}
value compile(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compile(ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -569,7 +612,10 @@ struct operation
virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
......@@ -630,6 +676,19 @@ struct operation
return detail::has_finalize_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed())
{
return private_detail_te_self.is_borrowed();
}
template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self)
{
return detail::is_borrowed_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_output_alias(char,
T&& private_detail_te_self,
......@@ -647,6 +706,27 @@ struct operation
return detail::output_alias_op(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_compile(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.compile(ctx, output, input))
{
return private_detail_te_self.compile(ctx, output, input);
}
template <class T>
static value private_detail_te_default_compile(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
{
return detail::compile_op(private_detail_te_self, ctx, output, input);
}
template <class T>
static auto private_detail_te_default_finalize(char,
T&& private_detail_te_self,
......@@ -858,12 +938,25 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
}
bool is_borrowed() const override
{
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
{
return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
}
value compile(context& ctx, const shape& output, const std::vector<shape>& input) override
{
return private_detail_te_default_compile(
char(0), private_detail_te_value, ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
{
......@@ -1010,6 +1103,24 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
return op.compile(ctx, output_shape, input);
}
template <class Context>
inline value
compile(operation& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
dependent_type<context, Context> ctx2 = std::ref(ctx);
return compile(op, ctx2, output_shape, input);
}
template <class T, class Context>
inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(op.compile(ctx, ctx, output_shape, input))
{
return op.compile(ctx, ctx, output_shape, input);
}
inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
{
return op.compute_shape(inputs);
......
......@@ -102,6 +102,29 @@ void reflect_each(T& x, F f)
});
}
template <class T>
struct reflect_equality
{
friend bool operator==(const T& x, const T& y) { return reflect_tie(x) == reflect_tie(y); }
friend bool operator!=(const T& x, const T& y) { return !(x == y); }
};
template <class T>
struct reflect_stream
{
template <class Stream>
friend Stream& operator<<(Stream& os, const T& x)
{
char d = '{';
reflect_each(x, [&](const auto& y, const auto& name) {
os << d << name << "=" << y;
d = ',';
});
os << "}";
return os;
}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
File mode changed from 100644 to 100755
......@@ -7,6 +7,7 @@
#include <migraphx/type_name.hpp>
#include <migraphx/rank.hpp>
#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <type_traits>
......@@ -58,8 +59,6 @@ struct value_converter<std::string>
{
static const std::string& apply(const std::string& x) { return x; }
static std::string apply(const std::nullptr_t&) { return "null"; }
template <class From>
static auto apply(const From& x)
-> decltype(std::declval<std::stringstream&>() << x, std::string())
......@@ -83,20 +82,28 @@ struct value_converter<std::pair<T, U>>
}
};
template <class To, class From>
To try_convert_value(const From& x);
namespace detail {
template <class To, class Key, class From>
auto try_convert_value_impl(rank<2>, const std::pair<Key, From>& x)
-> decltype(value_converter<To>::apply(x.second))
To try_convert_value_impl(rank<1>, const std::pair<Key, From>& x)
{
return value_converter<To>::apply(x.second);
return try_convert_value<To>(x.second);
}
template <class To, class From>
auto try_convert_value_impl(rank<1>, const From& x) -> decltype(value_converter<To>::apply(x))
auto try_convert_value_impl(rank<2>, const From& x) -> decltype(value_converter<To>::apply(x))
{
return value_converter<To>::apply(x);
}
template <class To, MIGRAPHX_REQUIRES(not std::is_same<To, std::nullptr_t>{})>
To try_convert_value_impl(rank<3>, std::nullptr_t)
{
MIGRAPHX_THROW("Incompatible values: null -> " + get_type_name<To>());
}
template <class To, class From>
To try_convert_value_impl(rank<0>, const From& x)
{
......@@ -107,7 +114,7 @@ To try_convert_value_impl(rank<0>, const From& x)
template <class To, class From>
To try_convert_value(const From& x)
{
return detail::try_convert_value_impl<To>(rank<2>{}, x);
return detail::try_convert_value_impl<To>(rank<3>{}, x);
}
struct value
......@@ -309,7 +316,11 @@ struct value
{
case null_type:
{
v(std::nullptr_t{});
std::nullptr_t null{};
if(this->key.empty())
v(null);
else
v(std::make_pair(this->get_key(), std::ref(null)));
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
......@@ -328,6 +339,31 @@ struct value
MIGRAPHX_THROW("Unknown type");
}
// Visit value without key
template <class Visitor>
void visit_value(Visitor v) const
{
switch(this->get_type())
{
case null_type:
{
std::nullptr_t null{};
v(null);
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
case vt##_type: \
{ \
v(this->get_##vt()); \
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
}
MIGRAPHX_THROW("Unknown type");
}
template <class To>
To to() const
{
......@@ -336,6 +372,14 @@ struct value
return result;
}
template <class To>
To value_or(const To& default_value) const
{
if(this->is_null())
return default_value;
return to<To>();
}
template <class To>
std::vector<To> to_vector() const
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment