Commit ee29e116 authored by Paul's avatar Paul
Browse files

Do generic nary broadcast

parent c1d244d9
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/requires.hpp>
#include <type_traits>
#include <array>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
template <class R, class...>
struct array_type { using type = R; };
template <class... Ts>
struct array_type<void, Ts...>
: std::common_type<Ts...> {};
template <class R, class... Ts>
using array_type_t = typename array_type<R, Ts...>::type;
template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N>
to_array_impl(T (&a)[N], seq<I...>)
{
return { {a[I]...} };
}
} // namespace detail
template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))>
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)>
make_array(Ts&&... xs)
{
return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))... };
}
constexpr std::array<int, 0> make_array()
{
return {};
}
template <class T, std::size_t N>
constexpr auto to_array(T (&a)[N])
{
return detail::to_array_impl(a, detail::gens<N>{});
}
namespace detail {
template <std::size_t Offset = 0, class Array, std::size_t... I>
constexpr auto
rearray_impl(Array a, seq<I...>)
{
return make_array(a[I+Offset]...);
}
} // namespace detail
template <class T, std::size_t N>
constexpr auto pop_front(std::array<T, N> a)
{
return detail::rearray_impl(a, detail::gens<N - 1>{});
}
template <class T, std::size_t N>
constexpr auto pop_back(std::array<T, N> a)
{
return detail::rearray_impl<1>(a, detail::gens<N - 1>{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -15,6 +15,12 @@ struct swallow ...@@ -15,6 +15,12 @@ struct swallow
} }
}; };
template<class T>
auto tuple_size(const T&)
{
return typename std::tuple_size<T>::type{};
}
namespace detail { namespace detail {
template <class R, class F> template <class R, class F>
...@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f) ...@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
return detail::sequence_c_impl(f, detail::gens<N>{}); return detail::sequence_c_impl(f, detail::gens<N>{});
} }
template <class IntegerConstant, class F>
constexpr auto sequence(IntegerConstant ic, F&& f)
{
return sequence_c<ic>(f);
}
template <class F, class... Ts> template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs) constexpr void each_args(F f, Ts&&... xs)
{ {
...@@ -95,9 +107,9 @@ constexpr void each_args(F) ...@@ -95,9 +107,9 @@ constexpr void each_args(F)
} }
template <class F, class T> template <class F, class T>
auto unpack(F f, T& x) auto unpack(F f, T&& x)
{ {
return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); }); return sequence(tuple_size(x), [&](auto... is) { f(std::get<is>(static_cast<T&&>(x))...); });
} }
/// Implements a fix-point combinator /// Implements a fix-point combinator
...@@ -149,6 +161,39 @@ auto index_of(T& x) ...@@ -149,6 +161,39 @@ auto index_of(T& x)
return [&](auto&& y) { return x[y]; }; return [&](auto&& y) { return x[y]; };
} }
template<class T, class... Ts>
decltype(auto) front_args(T&& x, Ts&&...)
{
return static_cast<T&&>(x);
}
template<class... Ts>
decltype(auto) back_args(Ts&&... xs)
{
return std::get<sizeof...(Ts) - 1>(std::tuple<Ts&&...>(static_cast<Ts&&>(xs)...));
}
template<class T, class... Ts>
auto pop_front_args(T&&, Ts&&... xs)
{
return [&](auto f) {
f(static_cast<Ts&&>(xs)...);
};
}
template<class... Ts>
auto pop_back_args(Ts&&... xs)
{
return [&](auto f) {
using tuple_type = std::tuple<Ts&&...>;
auto t = tuple_type(static_cast<Ts&&>(xs)...);
sequence_c<sizeof...(Ts) - 1>([&](auto... is) {
f(std::get<is>(static_cast<tuple_type&&>(t))...);
});
};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -33,6 +33,8 @@ auto generic_find_impl(rank<0>, C&& c, const T& x) ...@@ -33,6 +33,8 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return std::find(c.begin(), c.end(), x); return std::find(c.begin(), c.end(), x);
} }
struct empty {};
} // namespace detail } // namespace detail
template <class C, class T> template <class C, class T>
...@@ -71,6 +73,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p) ...@@ -71,6 +73,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return std::all_of(c.begin(), c.end(), p); return std::all_of(c.begin(), c.end(), p);
} }
template <class Predicate>
bool all_of(detail::empty, const Predicate&)
{
return true;
}
template <class C, class Predicate> template <class C, class Predicate>
bool any_of(const C& c, const Predicate& p) bool any_of(const C& c, const Predicate& p)
{ {
...@@ -83,6 +91,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p) ...@@ -83,6 +91,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
return std::any_of(c.begin(), c.end(), p); return std::any_of(c.begin(), c.end(), p);
} }
template <class Predicate>
bool any_of(detail::empty, const Predicate&)
{
return false;
}
template <class C, class Predicate> template <class C, class Predicate>
bool none_of(const C& c, const Predicate& p) bool none_of(const C& c, const Predicate& p)
{ {
...@@ -95,6 +109,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p) ...@@ -95,6 +109,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
return std::none_of(c.begin(), c.end(), p); return std::none_of(c.begin(), c.end(), p);
} }
template <class Predicate>
bool none_of(detail::empty, const Predicate&)
{
return true;
}
template <class Range, class Iterator> template <class Range, class Iterator>
void copy(Range&& r, Iterator it) void copy(Range&& r, Iterator it)
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/array.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -257,6 +258,45 @@ void binary_broadcast_impl( ...@@ -257,6 +258,45 @@ void binary_broadcast_impl(
}); });
} }
template <class F, class... Arguments>
void nary_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = binput.data()[i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx];
output.data()[i] = f(inputs.data()[i]..., b);
}
});
});
}
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args) void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
{ {
...@@ -304,12 +344,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args) ...@@ -304,12 +344,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
nary_nonstandard_impl(stream, f, result, args...); nary_nonstandard_impl(stream, f, result, args...);
} }
template <class F>
void nary_impl(hipStream_t stream, F f, argument result)
{
nary_standard_impl(stream, f, result);
}
template <class... Arguments> template <class... Arguments>
auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args) auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
{ {
...@@ -323,71 +357,49 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args) ...@@ -323,71 +357,49 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
} }
template <class... Arguments> template <class... Arguments>
auto nary(hipStream_t stream, argument result, Arguments... args) auto nary(hipStream_t stream, argument result)
{ {
return [=](auto f) { nary_impl(stream, f, result, args...); }; return [=](auto f) { nary_standard_impl(stream, f, result); };
} }
inline auto template <class... Arguments>
nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) auto
nary(hipStream_t stream, argument result, Arguments... args)
{ {
return [=](auto f) {
// TODO: Check result and arg1 shape is the same
if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and
not arg2.get_shape().scalar())
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg2.get_shape().strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(arg2.get_shape().lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4)
binary_broadcast_vec_impl(stream, f, result, arg1, arg2);
else
binary_broadcast_impl(stream, f, result, arg1, arg2);
return;
}
}
nary_impl(stream, f, result, arg1, arg2);
};
}
inline auto nary(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
return [=](auto f) { return [=](auto f) {
// TODO: Check result and arg1 shape is the same auto barg = back_args(args...);
if(arg1.get_shape().standard() and arg2.get_shape().standard() and pop_back_args(args...)([&](auto&&... args2) {
arg3.get_shape().broadcasted()) auto bshape = barg.get_shape();
{ const bool standard = all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
auto not_zero = [](auto x) { return x != 0; }; const bool same_shapes =
const auto& strides = arg3.get_shape().strides(); all_of({args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero); // TODO: Check result and args shape is the same
auto b_idx = std::distance(strides.begin(), b_it); if(standard and same_shapes and bshape.broadcasted() and
auto b_len = result.get_shape().lens()[b_idx]; not bshape.scalar())
auto b_stride = result.get_shape().strides()[b_idx];
assert(arg3.get_shape().lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{ {
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and auto not_zero = [](auto x) { return x != 0; };
(arg1.get_shape().elements() % 4 == 0); const auto& strides = bshape.strides();
if(divisible_by_4) auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
trinary_broadcast_vec_impl(stream, f, result, arg1, arg2, arg3); auto b_idx = std::distance(strides.begin(), b_it);
else auto b_len = result.get_shape().lens()[b_idx];
trinary_broadcast_impl(stream, f, result, arg1, arg2, arg3); auto b_stride = result.get_shape().strides()[b_idx];
return; assert(bshape.lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
nary_broadcast_impl(stream, f, result, barg, args2...);
// const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
// (arg1.get_shape().elements() % 4 == 0);
// if(divisible_by_4)
// binary_broadcast_vec_impl(stream, f, result, arg1, arg);
// else
// binary_broadcast_impl(stream, f, result, arg1, arg);
// return;
}
} }
} });
nary_impl(stream, f, result, arg1, arg2, arg3); nary_impl(stream, f, result, args...);
}; };
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/gpu/device/add_relu.hpp> #include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add.hpp> #include <migraphx/gpu/device/add.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/array.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -122,13 +123,6 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins) ...@@ -122,13 +123,6 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
s.strides()[1] != 0 and s.strides()[2] == 0 and s.strides()[3] == 0; s.strides()[1] != 0 and s.strides()[2] == 0 and s.strides()[3] == 0;
} }
// TODO: Move to another header
template <class T, class... Ts>
std::array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
return {std::move(x), std::move(static_cast<T>(xs))...};
}
MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
{ {
if(ins->name() != "gpu::convolution") if(ins->name() != "gpu::convolution")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment