Commit ee29e116 authored by Paul's avatar Paul
Browse files

Do generic nary broadcast

parent c1d244d9
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/requires.hpp>
#include <type_traits>
#include <array>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
// Computes the element type used by make_array: the explicitly requested
// type R, unless R is void, in which case the common type of Ts... is used.
template <class R, class...>
struct array_type
{
    using type = R;
};

template <class... Ts>
struct array_type<void, Ts...> : std::common_type<Ts...>
{
};

// Convenience alias for array_type<R, Ts...>::type.
template <class R, class... Ts>
using array_type_t = typename array_type<R, Ts...>::type;
template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N>
to_array_impl(T (&a)[N], seq<I...>)
{
return { {a[I]...} };
}
} // namespace detail
// Packs the arguments into a std::array. The element type is Result when
// given explicitly, otherwise the common type of the arguments; each
// argument is converted (and forwarded) into that type.
template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))>
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)>
make_array(Ts&&... xs)
{
    using value_type = detail::array_type_t<Result, Ts...>;
    return {static_cast<value_type>(std::forward<Ts>(xs))...};
}
// Zero-argument overload: with no arguments to deduce from, the empty array
// defaults to element type int.
constexpr std::array<int, 0> make_array()
{
    return std::array<int, 0>{};
}
// Converts a built-in array into a std::array of the same length, delegating
// to to_array_impl with an index sequence of 0..N-1.
template <class T, std::size_t N>
constexpr auto to_array(T (&arr)[N])
{
    return detail::to_array_impl(arr, detail::gens<N>{});
}
namespace detail {
// Repacks elements [Offset, Offset + sizeof...(I)) of the input array into a
// fresh array via make_array; used to implement pop_front/pop_back.
template <std::size_t Offset = 0, class Array, std::size_t... I>
constexpr auto rearray_impl(Array arr, seq<I...>)
{
    return make_array(arr[Offset + I]...);
}
} // namespace detail
template <class T, std::size_t N>
constexpr auto pop_front(std::array<T, N> a)
{
return detail::rearray_impl(a, detail::gens<N - 1>{});
}
// Returns a copy of the array with the LAST element removed.
//
// Fix: this previously forwarded to rearray_impl<1>, which keeps elements
// [1, N-1] and therefore dropped the FIRST element (pop_front's behavior).
// The default offset of 0 keeps elements [0, N-2] as intended.
template <class T, std::size_t N>
constexpr auto pop_back(std::array<T, N> a)
{
    return detail::rearray_impl(a, detail::gens<N - 1>{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -15,6 +15,12 @@ struct swallow
}
};
// Returns the size of a tuple-like value as a std::integral_constant, so the
// result can be used where a compile-time constant is required (e.g. as a
// template argument) even though it is obtained from a runtime value.
template <class T>
auto tuple_size(const T&)
{
    return std::integral_constant<std::size_t, std::tuple_size<T>::value>{};
}
namespace detail {
template <class R, class F>
......@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
return detail::sequence_c_impl(f, detail::gens<N>{});
}
// Value-taking wrapper over sequence_c: presumably invokes f with the index
// pack 0..ic-1 (see sequence_c) -- confirm against its definition. Passing
// `ic` as a template argument is valid even though it is a function
// parameter, because an integral_constant's conversion operator yields its
// static value without reading the object, so it is a constant expression.
template <class IntegerConstant, class F>
constexpr auto sequence(IntegerConstant ic, F&& f)
{
return sequence_c<ic>(f);
}
template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs)
{
......@@ -95,9 +107,9 @@ constexpr void each_args(F)
}
template <class F, class T>
auto unpack(F f, T& x)
auto unpack(F f, T&& x)
{
return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
return sequence(tuple_size(x), [&](auto... is) { f(std::get<is>(static_cast<T&&>(x))...); });
}
/// Implements a fix-point combinator
......@@ -149,6 +161,39 @@ auto index_of(T& x)
return [&](auto&& y) { return x[y]; };
}
// Returns the first argument with its value category preserved; all trailing
// arguments are ignored.
template <class T, class... Ts>
decltype(auto) front_args(T&& first, Ts&&...)
{
    return static_cast<T&&>(first);
}
// Returns the last argument with its value category preserved, by packing
// the arguments into a tuple of references and extracting the final element.
template <class... Ts>
decltype(auto) back_args(Ts&&... xs)
{
    using ref_tuple = std::tuple<Ts&&...>;
    return std::get<sizeof...(Ts) - 1>(ref_tuple(static_cast<Ts&&>(xs)...));
}
// Drops the first argument and returns a callable that applies a function to
// the remaining ones. The arguments are held by reference, so the result is
// intended for immediate invocation:
//   pop_front_args(xs...)([](auto&&... rest) { ... });
template <class T, class... Ts>
auto pop_front_args(T&&, Ts&&... rest)
{
    return [&](auto f) { f(static_cast<Ts&&>(rest)...); };
}
// Drops the last argument and returns a callable that applies a function to
// the remaining ones. The arguments are captured by reference, so the result
// must be invoked before the arguments leave scope:
//   pop_back_args(xs...)([](auto&&... rest) { ... });
template<class... Ts>
auto pop_back_args(Ts&&... xs)
{
return [&](auto f) {
// Pack all arguments into a tuple of references, then expand only the
// first sizeof...(Ts) - 1 of them into the call to f.
using tuple_type = std::tuple<Ts&&...>;
auto t = tuple_type(static_cast<Ts&&>(xs)...);
sequence_c<sizeof...(Ts) - 1>([&](auto... is) {
f(std::get<is>(static_cast<tuple_type&&>(t))...);
});
};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -33,6 +33,8 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return std::find(c.begin(), c.end(), x);
}
// Tag type standing in for an empty range, giving all_of/any_of/none_of
// trivially-constant overloads (used when an argument pack is empty).
struct empty {};
} // namespace detail
template <class C, class T>
......@@ -71,6 +73,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return std::all_of(c.begin(), c.end(), p);
}
// Vacuous truth: every element of an empty range satisfies the predicate.
template <class F>
bool all_of(detail::empty, const F&)
{
    return true;
}
template <class C, class Predicate>
bool any_of(const C& c, const Predicate& p)
{
......@@ -83,6 +91,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
return std::any_of(c.begin(), c.end(), p);
}
// An empty range contains no element satisfying the predicate.
template <class F>
bool any_of(detail::empty, const F&)
{
    return false;
}
template <class C, class Predicate>
bool none_of(const C& c, const Predicate& p)
{
......@@ -95,6 +109,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
return std::none_of(c.begin(), c.end(), p);
}
// Vacuously true: an empty range has no element satisfying the predicate.
template <class F>
bool none_of(detail::empty, const F&)
{
    return true;
}
template <class Range, class Iterator>
void copy(Range&& r, Iterator it)
{
......
......@@ -6,6 +6,7 @@
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/array.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......@@ -257,6 +258,45 @@ void binary_broadcast_impl(
});
}
// Applies f elementwise across the inputs, with one argument (barg)
// broadcast along a single axis. The broadcast values are staged into
// shared memory (LDS) once per block and reused for every output element,
// avoiding repeated global-memory loads.
//
// NOTE(review): relies on barg having exactly one non-zero stride, on the
// broadcast axis length being <= 2048 (the fixed LDS buffer size), and on
// barg not being scalar (all-zero strides would make bdim index one past
// the end of lens()). The nary() dispatcher appears to check these --
// confirm for any new callers.
template <class F, class... Arguments>
void nary_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape();
// Index of the first (and assumed only) axis with a non-zero stride: the
// axis along which barg actually varies.
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
// Distance between consecutive groups of elements that cycle through the
// same broadcast index.
auto bdim_next_stride = bdim_stride * bdim_len;
// Launch configuration: nlocal threads per block, nglobal threads total.
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048];
// Load bias into LDS; each thread copies a strided subset of barg.
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = binput.data()[i];
}
// Every thread must see the fully populated buffer before reading it.
__syncthreads();
// Process the data: each thread strides over the flat output range.
for(size_t i = idx.global; i < nelements; i += nglobal)
{
// Map the flat index i back to its coordinate on the broadcast axis.
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx];
// The broadcast value is passed as the last argument to f.
output.data()[i] = f(inputs.data()[i]..., b);
}
});
});
}
template <class F, class... Arguments>
void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
......@@ -304,12 +344,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
nary_nonstandard_impl(stream, f, result, args...);
}
template <class F>
void nary_impl(hipStream_t stream, F f, argument result)
{
nary_standard_impl(stream, f, result);
}
template <class... Arguments>
auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
{
......@@ -323,71 +357,49 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
}
template <class... Arguments>
auto nary(hipStream_t stream, argument result, Arguments... args)
auto nary(hipStream_t stream, argument result)
{
return [=](auto f) { nary_impl(stream, f, result, args...); };
return [=](auto f) { nary_standard_impl(stream, f, result); };
}
inline auto
nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
template <class... Arguments>
auto
nary(hipStream_t stream, argument result, Arguments... args)
{
return [=](auto f) {
// TODO: Check result and arg1 shape is the same
if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and
not arg2.get_shape().scalar())
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg2.get_shape().strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(arg2.get_shape().lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4)
binary_broadcast_vec_impl(stream, f, result, arg1, arg2);
else
binary_broadcast_impl(stream, f, result, arg1, arg2);
return;
}
}
nary_impl(stream, f, result, arg1, arg2);
};
}
inline auto nary(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
return [=](auto f) {
// TODO: Check result and arg1 shape is the same
if(arg1.get_shape().standard() and arg2.get_shape().standard() and
arg3.get_shape().broadcasted())
auto barg = back_args(args...);
pop_back_args(args...)([&](auto&&... args2) {
auto bshape = barg.get_shape();
const bool standard = all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
const bool same_shapes =
all_of({args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
// TODO: Check result and args shape is the same
if(standard and same_shapes and bshape.broadcasted() and
not bshape.scalar())
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg3.get_shape().strides();
const auto& strides = bshape.strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(arg3.get_shape().lens()[b_idx] == b_len);
assert(bshape.lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4)
trinary_broadcast_vec_impl(stream, f, result, arg1, arg2, arg3);
else
trinary_broadcast_impl(stream, f, result, arg1, arg2, arg3);
return;
nary_broadcast_impl(stream, f, result, barg, args2...);
// const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
// (arg1.get_shape().elements() % 4 == 0);
// if(divisible_by_4)
// binary_broadcast_vec_impl(stream, f, result, arg1, arg);
// else
// binary_broadcast_impl(stream, f, result, arg1, arg);
// return;
}
}
nary_impl(stream, f, result, arg1, arg2, arg3);
});
nary_impl(stream, f, result, args...);
};
}
......
......@@ -5,6 +5,7 @@
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/array.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -122,13 +123,6 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
s.strides()[1] != 0 and s.strides()[2] == 0 and s.strides()[3] == 0;
}
// TODO: Move to another header (duplicates migraphx::make_array in array.hpp)
// Packs a leading value plus a trailing pack into a std::array<T>, converting
// each trailing argument to T.
template <class T, class... Ts>
std::array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
    // Fix: was `std::move(static_cast<T>(xs))`, which copies xs into the
    // cast and then applies a redundant move to the resulting temporary.
    // Moving *into* the cast lets each by-value parameter be moved instead.
    return {std::move(x), static_cast<T>(std::move(xs))...};
}
MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
if(ins->name() != "gpu::convolution")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment