Commit 8ec57ece authored by Paul's avatar Paul
Browse files

Formatting

parent ee29e116
...@@ -13,34 +13,33 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -13,34 +13,33 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace detail { namespace detail {
template <class R, class...> template <class R, class...>
struct array_type { using type = R; }; struct array_type
{
using type = R;
};
template <class... Ts> template <class... Ts>
struct array_type<void, Ts...> struct array_type<void, Ts...> : std::common_type<Ts...>
: std::common_type<Ts...> {}; {
};
template <class R, class... Ts> template <class R, class... Ts>
using array_type_t = typename array_type<R, Ts...>::type; using array_type_t = typename array_type<R, Ts...>::type;
template <class T, std::size_t N, std::size_t... I> template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N> constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], seq<I...>)
to_array_impl(T (&a)[N], seq<I...>)
{ {
return { {a[I]...} }; return {{a[I]...}};
} }
} // namespace detail } // namespace detail
template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))> template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))>
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)> constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)> make_array(Ts&&... xs)
make_array(Ts&&... xs)
{ {
return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))... }; return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))...};
} }
constexpr std::array<int, 0> make_array() constexpr std::array<int, 0> make_array() { return {}; }
{
return {};
}
template <class T, std::size_t N> template <class T, std::size_t N>
constexpr auto to_array(T (&a)[N]) constexpr auto to_array(T (&a)[N])
...@@ -51,10 +50,9 @@ constexpr auto to_array(T (&a)[N]) ...@@ -51,10 +50,9 @@ constexpr auto to_array(T (&a)[N])
namespace detail { namespace detail {
template <std::size_t Offset = 0, class Array, std::size_t... I> template <std::size_t Offset = 0, class Array, std::size_t... I>
constexpr auto constexpr auto rearray_impl(Array a, seq<I...>)
rearray_impl(Array a, seq<I...>)
{ {
return make_array(a[I+Offset]...); return make_array(a[I + Offset]...);
} }
} // namespace detail } // namespace detail
......
...@@ -15,7 +15,7 @@ struct swallow ...@@ -15,7 +15,7 @@ struct swallow
} }
}; };
template<class T> template <class T>
auto tuple_size(const T&) auto tuple_size(const T&)
{ {
return typename std::tuple_size<T>::type{}; return typename std::tuple_size<T>::type{};
...@@ -161,39 +161,35 @@ auto index_of(T& x) ...@@ -161,39 +161,35 @@ auto index_of(T& x)
return [&](auto&& y) { return x[y]; }; return [&](auto&& y) { return x[y]; };
} }
template<class T, class... Ts> template <class T, class... Ts>
decltype(auto) front_args(T&& x, Ts&&...) decltype(auto) front_args(T&& x, Ts&&...)
{ {
return static_cast<T&&>(x); return static_cast<T&&>(x);
} }
template<class... Ts> template <class... Ts>
decltype(auto) back_args(Ts&&... xs) decltype(auto) back_args(Ts&&... xs)
{ {
return std::get<sizeof...(Ts) - 1>(std::tuple<Ts&&...>(static_cast<Ts&&>(xs)...)); return std::get<sizeof...(Ts) - 1>(std::tuple<Ts&&...>(static_cast<Ts&&>(xs)...));
} }
template<class T, class... Ts> template <class T, class... Ts>
auto pop_front_args(T&&, Ts&&... xs) auto pop_front_args(T&&, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) { f(static_cast<Ts&&>(xs)...); };
f(static_cast<Ts&&>(xs)...);
};
} }
template<class... Ts> template <class... Ts>
auto pop_back_args(Ts&&... xs) auto pop_back_args(Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
using tuple_type = std::tuple<Ts&&...>; using tuple_type = std::tuple<Ts&&...>;
auto t = tuple_type(static_cast<Ts&&>(xs)...); auto t = tuple_type(static_cast<Ts&&>(xs)...);
sequence_c<sizeof...(Ts) - 1>([&](auto... is) { sequence_c<sizeof...(Ts) - 1>(
f(std::get<is>(static_cast<tuple_type&&>(t))...); [&](auto... is) { f(std::get<is>(static_cast<tuple_type&&>(t))...); });
});
}; };
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -33,7 +33,9 @@ auto generic_find_impl(rank<0>, C&& c, const T& x) ...@@ -33,7 +33,9 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return std::find(c.begin(), c.end(), x); return std::find(c.begin(), c.end(), x);
} }
struct empty {}; struct empty
{
};
} // namespace detail } // namespace detail
......
...@@ -259,8 +259,7 @@ void binary_broadcast_impl( ...@@ -259,8 +259,7 @@ void binary_broadcast_impl(
} }
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_broadcast_impl( void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape(); const auto& b_shape = barg.get_shape();
...@@ -363,20 +362,19 @@ auto nary(hipStream_t stream, argument result) ...@@ -363,20 +362,19 @@ auto nary(hipStream_t stream, argument result)
} }
template <class... Arguments> template <class... Arguments>
auto auto nary(hipStream_t stream, argument result, Arguments... args)
nary(hipStream_t stream, argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) {
auto barg = back_args(args...); auto barg = back_args(args...);
pop_back_args(args...)([&](auto&&... args2) { pop_back_args(args...)([&](auto&&... args2) {
auto bshape = barg.get_shape(); auto bshape = barg.get_shape();
const bool standard = all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); }); const bool standard =
const bool same_shapes = all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
all_of({args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); }); const bool same_shapes = all_of(
{args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
// TODO: Check result and args shape is the same // TODO: Check result and args shape is the same
if(standard and same_shapes and bshape.broadcasted() and if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
not bshape.scalar())
{ {
auto not_zero = [](auto x) { return x != 0; }; auto not_zero = [](auto x) { return x != 0; };
const auto& strides = bshape.strides(); const auto& strides = bshape.strides();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment