Commit c97080ce authored by Paul
Browse files

Fuse transpose

parent f17d6246
......@@ -28,6 +28,7 @@
#include <migraphx/value.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <type_traits>
......@@ -87,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
return result;
}
// Serializes an optional<T> into a value.
// An engaged optional yields to_value(*x); a disengaged optional yields a
// default-constructed value.
template <class T>
auto to_value_impl(rank<4>, const optional<T>& x)
{
    value result{};
    // The converted payload has to be stored in the result; calling
    // to_value(*x) without keeping its return value would make every
    // optional serialize as if it were empty.
    if(x.has_value())
        result = to_value(*x);
    return result;
}
template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})>
value to_value_impl(rank<4>, const T& x)
value to_value_impl(rank<5>, const T& x)
{
return std::int64_t{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})>
value to_value_impl(rank<5>, const T& x)
value to_value_impl(rank<6>, const T& x)
{
return std::uint64_t{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
value to_value_impl(rank<6>, const T& x)
value to_value_impl(rank<7>, const T& x)
{
return double{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
value to_value_impl(rank<7>, const T& x)
value to_value_impl(rank<8>, const T& x)
{
return x;
}
inline value to_value_impl(rank<8>, const std::string& x) { return x; }
inline value to_value_impl(rank<9>, const std::string& x) { return x; }
template <class T>
auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x))
auto to_value_impl(rank<10>, const T& x) -> decltype(migraphx_to_value(x))
{
return migraphx_to_value(x);
}
template <class T>
auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value())
auto to_value_impl(rank<11>, const T& x) -> decltype(x.to_value())
{
return x.to_value();
}
template <class T>
auto to_value_impl(rank<11>, const T& x)
auto to_value_impl(rank<12>, const T& x)
-> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{
value v;
......@@ -195,28 +205,35 @@ void from_value_impl(rank<5>, const value& v, T& x)
});
}
// Deserializes an optional<T> from a value.
// A null value leaves x untouched; any other value is converted through
// from_value<T> and assigned to x.
template <class T>
void from_value_impl(rank<6>, const value& v, optional<T>& x)
{
    if(v.is_null())
        return;
    x = from_value<T>(v);
}
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
void from_value_impl(rank<6>, const value& v, T& x)
void from_value_impl(rank<7>, const value& v, T& x)
{
x = v.to<T>();
}
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
void from_value_impl(rank<7>, const value& v, T& x)
void from_value_impl(rank<8>, const value& v, T& x)
{
x = v.to<T>();
}
inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); }
inline void from_value_impl(rank<9>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T>
auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void())
auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(x.from_value(v), void())
{
x.from_value(v);
}
template <class T>
auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
auto from_value_impl(rank<11>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{
migraphx_from_value(v, x);
}
......@@ -226,13 +243,13 @@ auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_v
template <class T>
value to_value(const T& x)
{
return detail::to_value_impl(rank<11>{}, x);
return detail::to_value_impl(rank<12>{}, x);
}
template <class T>
void from_value(const value& v, T& x)
{
detail::from_value_impl(rank<10>{}, v, x);
detail::from_value_impl(rank<11>{}, v, x);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -29,6 +29,7 @@
#include <migraphx/reflect.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/config.hpp>
#include <vector>
......@@ -99,12 +100,21 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
os << "}";
}
// Streams an optional<T>: writes the literal text "none" when the optional is
// empty, otherwise forwards the contained object to the rank<2> dispatch of
// stream_write_value_impl.
template <class T>
void stream_write_value_impl(rank<0>, std::ostream& os, const optional<T>& x)
{
    if(not x.has_value())
    {
        os << "none";
        return;
    }
    stream_write_value_impl(rank<2>{}, os, *x);
}
} // namespace detail
template <class T>
void stream_write_value(std::ostream& os, const T& x)
{
detail::stream_write_value_impl(rank<1>{}, os, x);
detail::stream_write_value_impl(rank<2>{}, os, x);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -42,13 +42,15 @@ struct precompile_op
operation op = op::identity{};
std::size_t additional_args = 1;
bool ignore_modules = false;
optional<shape> output_shape = {};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"),
f(self.additional_args, "additional_args"),
f(self.ignore_modules, "ignore_modules"));
f(self.ignore_modules, "ignore_modules"),
f(self.output_shape, "output_shape"));
}
std::string name() const { return "gpu::precompile_op"; }
......@@ -57,9 +59,14 @@ struct precompile_op
{
// Pop off additional args
inputs.resize(inputs.size() - additional_args);
shape r{};
if(ignore_modules)
return op.compute_shape(inputs);
return op.compute_shape(inputs, mods);
r = op.compute_shape(inputs);
else
r = op.compute_shape(inputs, mods);
if (output_shape.has_value())
r = *output_shape;
return r;
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
......
......@@ -650,6 +650,32 @@ struct find_gemm_pointwise
}
};
// Matcher/rewriter for the pattern
//     gpu::contiguous(transpose(gpu::precompile_op))
// where the gpu::precompile_op result is used exactly once.
//
// apply() rebuilds the gpu::precompile_op with an "output_shape" attribute
// derived from the inverse of the transpose permutation. Only the
// precompile_op instruction is replaced here; the matched transpose and
// gpu::contiguous instructions are not modified by this struct.
struct find_contiguous_tranpose_precompile
{
    // Builds the matcher; binds the precompile_op as "op" and the transpose
    // as "transpose" so apply() can look them up.
    auto matcher() const
    {
        return match::name("gpu::contiguous")(match::arg(0)(
            match::name("transpose")(
                match::arg(0)(match::name("gpu::precompile_op")(match::used_once()).bind("op")))
                .bind("transpose")));
    }

    // m: module that owns the matched instructions.
    // r: match result; must contain the "op" and "transpose" bindings.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto op_ins    = r.instructions["op"];
        auto transpose = r.instructions["transpose"];

        // Read the permutation from the transpose operator and invert it.
        auto perm  = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
        auto iperm = invert_permutation(perm);

        // Same element type and lens as the current precompile_op output,
        // with the shape constructed from the inverted permutation.
        auto s = shape::from_permutation(
            op_ins->get_shape().type(), op_ins->get_shape().lens(), iperm);

        // Copy the existing operator attributes and attach the new shape.
        auto v            = op_ins->get_operator().to_value();
        v["output_shape"] = to_value(s);
        auto new_op       = make_op("gpu::precompile_op", v);

        // Inputs and module inputs are carried over unchanged.
        m.replace_instruction(op_ins, new_op, op_ins->inputs(), op_ins->module_inputs());
    }
};
struct find_contiguous_tranpose_gemm
{
auto matcher() const
......@@ -825,6 +851,7 @@ void fuse_ops::apply(module& m) const
find_concat_pointwise{},
find_gemm_pointwise{},
find_contiguous_tranpose_gemm{},
find_contiguous_tranpose_precompile{},
find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment