Commit c97080ce authored by Paul
Browse files

Fuse transpose

parent f17d6246
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <type_traits> #include <type_traits>
...@@ -87,46 +88,55 @@ value to_value_impl(rank<3>, const T& x) ...@@ -87,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
return result; return result;
} }
// Serialize an optional<T>.
// An engaged optional is converted by calling to_value on the contained object;
// a disengaged optional yields a default-constructed value (presumably the null
// value that the matching from_value_impl tests with is_null() - confirm).
template <class T>
auto to_value_impl(rank<4>, const optional<T>& x)
{
    value result{};
    // Keep the converted payload: previously the return of to_value(*x) was
    // discarded, so every optional serialized as the empty default value.
    if(x.has_value())
        result = to_value(*x);
    return result;
}
template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})>
value to_value_impl(rank<4>, const T& x) value to_value_impl(rank<5>, const T& x)
{ {
return std::int64_t{x}; return std::int64_t{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})>
value to_value_impl(rank<5>, const T& x) value to_value_impl(rank<6>, const T& x)
{ {
return std::uint64_t{x}; return std::uint64_t{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
value to_value_impl(rank<6>, const T& x) value to_value_impl(rank<7>, const T& x)
{ {
return double{x}; return double{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
value to_value_impl(rank<7>, const T& x) value to_value_impl(rank<8>, const T& x)
{ {
return x; return x;
} }
inline value to_value_impl(rank<8>, const std::string& x) { return x; } inline value to_value_impl(rank<9>, const std::string& x) { return x; }
template <class T> template <class T>
auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x)) auto to_value_impl(rank<10>, const T& x) -> decltype(migraphx_to_value(x))
{ {
return migraphx_to_value(x); return migraphx_to_value(x);
} }
template <class T> template <class T>
auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value()) auto to_value_impl(rank<11>, const T& x) -> decltype(x.to_value())
{ {
return x.to_value(); return x.to_value();
} }
template <class T> template <class T>
auto to_value_impl(rank<11>, const T& x) auto to_value_impl(rank<12>, const T& x)
-> decltype(migraphx_to_value(std::declval<value&>(), x), value{}) -> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{ {
value v; value v;
...@@ -195,28 +205,35 @@ void from_value_impl(rank<5>, const value& v, T& x) ...@@ -195,28 +205,35 @@ void from_value_impl(rank<5>, const value& v, T& x)
}); });
} }
// Deserialize into an optional<T>.
// A null value leaves x exactly as the caller passed it in; any other value is
// converted with from_value<T> and assigned to x.
template <class T>
void from_value_impl(rank<6>, const value& v, optional<T>& x)
{
    // Nothing to read from a null value.
    if(v.is_null())
        return;
    x = from_value<T>(v);
}
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
void from_value_impl(rank<6>, const value& v, T& x) void from_value_impl(rank<7>, const value& v, T& x)
{ {
x = v.to<T>(); x = v.to<T>();
} }
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
void from_value_impl(rank<7>, const value& v, T& x) void from_value_impl(rank<8>, const value& v, T& x)
{ {
x = v.to<T>(); x = v.to<T>();
} }
inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); } inline void from_value_impl(rank<9>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T> template <class T>
auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void()) auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(x.from_value(v), void())
{ {
x.from_value(v); x.from_value(v);
} }
template <class T> template <class T>
auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void()) auto from_value_impl(rank<11>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{ {
migraphx_from_value(v, x); migraphx_from_value(v, x);
} }
...@@ -226,13 +243,13 @@ auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_v ...@@ -226,13 +243,13 @@ auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_v
template <class T> template <class T>
value to_value(const T& x) value to_value(const T& x)
{ {
return detail::to_value_impl(rank<11>{}, x); return detail::to_value_impl(rank<12>{}, x);
} }
template <class T> template <class T>
void from_value(const value& v, T& x) void from_value(const value& v, T& x)
{ {
detail::from_value_impl(rank<10>{}, v, x); detail::from_value_impl(rank<11>{}, v, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <vector> #include <vector>
...@@ -99,12 +100,21 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x) ...@@ -99,12 +100,21 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
os << "}"; os << "}";
} }
// Stream an optional<T>: writes the literal text "none" when the optional is
// empty, otherwise forwards the contained object to the rank<2> dispatch of
// stream_write_value_impl.
template <class T>
void stream_write_value_impl(rank<0>, std::ostream& os, const optional<T>& x)
{
    if(not x.has_value())
    {
        os << "none";
        return;
    }
    stream_write_value_impl(rank<2>{}, os, *x);
}
} // namespace detail } // namespace detail
template <class T> template <class T>
void stream_write_value(std::ostream& os, const T& x) void stream_write_value(std::ostream& os, const T& x)
{ {
detail::stream_write_value_impl(rank<1>{}, os, x); detail::stream_write_value_impl(rank<2>{}, os, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -42,13 +42,15 @@ struct precompile_op ...@@ -42,13 +42,15 @@ struct precompile_op
operation op = op::identity{}; operation op = op::identity{};
std::size_t additional_args = 1; std::size_t additional_args = 1;
bool ignore_modules = false; bool ignore_modules = false;
optional<shape> output_shape = {};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.op, "op"), return pack(f(self.op, "op"),
f(self.additional_args, "additional_args"), f(self.additional_args, "additional_args"),
f(self.ignore_modules, "ignore_modules")); f(self.ignore_modules, "ignore_modules"),
f(self.output_shape, "output_shape"));
} }
std::string name() const { return "gpu::precompile_op"; } std::string name() const { return "gpu::precompile_op"; }
...@@ -57,9 +59,14 @@ struct precompile_op ...@@ -57,9 +59,14 @@ struct precompile_op
{ {
// Pop off additional args // Pop off additional args
inputs.resize(inputs.size() - additional_args); inputs.resize(inputs.size() - additional_args);
shape r{};
if(ignore_modules) if(ignore_modules)
return op.compute_shape(inputs); r = op.compute_shape(inputs);
return op.compute_shape(inputs, mods); else
r = op.compute_shape(inputs, mods);
if (output_shape.has_value())
r = *output_shape;
return r;
} }
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
......
...@@ -650,6 +650,32 @@ struct find_gemm_pointwise ...@@ -650,6 +650,32 @@ struct find_gemm_pointwise
} }
}; };
// Matches gpu::contiguous(transpose(gpu::precompile_op)) where the precompile op
// has a single user, and rewrites the precompile op so that it carries an explicit
// "output_shape" built from the inverse of the transpose permutation.
// Only the precompile op instruction is replaced here; the transpose and the
// contiguous instructions are not touched by this matcher.
// NOTE(review): presumably the permuted output layout lets a later pass drop the
// contiguous - confirm against the passes that run after this one.
struct find_contiguous_tranpose_precompile
{
    // Binds the precompile op as "op" and the transpose as "transpose".
    auto matcher() const
    {
        return match::name("gpu::contiguous")(match::arg(0)(
            match::name("transpose")(
                match::arg(0)(match::name("gpu::precompile_op")(match::used_once()).bind("op")))
                .bind("transpose")));
    }

    // m: module containing the matched instructions.
    // r: match result holding the "op" and "transpose" bindings from matcher().
    void apply(module& m, const match::matcher_result& r) const
    {
        auto op_ins    = r.instructions["op"];
        auto transpose = r.instructions["transpose"];

        // Permutation applied by the transpose, and its inverse.
        auto perm  = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
        auto iperm = invert_permutation(perm);

        // Same element type and lens as the op's current output, laid out
        // according to the inverse permutation.
        const auto& op_shape = op_ins->get_shape();
        auto s               = shape::from_permutation(op_shape.type(), op_shape.lens(), iperm);

        // Rebuild the precompile op from its serialized form with output_shape set.
        auto v            = op_ins->get_operator().to_value();
        v["output_shape"] = to_value(s);
        auto new_op       = make_op("gpu::precompile_op", v);
        m.replace_instruction(op_ins, new_op, op_ins->inputs(), op_ins->module_inputs());
    }
};
struct find_contiguous_tranpose_gemm struct find_contiguous_tranpose_gemm
{ {
auto matcher() const auto matcher() const
...@@ -825,6 +851,7 @@ void fuse_ops::apply(module& m) const ...@@ -825,6 +851,7 @@ void fuse_ops::apply(module& m) const
find_concat_pointwise{}, find_concat_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_contiguous_tranpose_gemm{}, find_contiguous_tranpose_gemm{},
find_contiguous_tranpose_precompile{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment