Commit 8ec57ece authored by Paul's avatar Paul
Browse files

Formatting

parent ee29e116
......@@ -13,34 +13,33 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
template <class R, class...>
struct array_type { using type = R; };
struct array_type
{
using type = R;
};
template <class... Ts>
struct array_type<void, Ts...>
: std::common_type<Ts...> {};
struct array_type<void, Ts...> : std::common_type<Ts...>
{
};
template <class R, class... Ts>
using array_type_t = typename array_type<R, Ts...>::type;
template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N>
to_array_impl(T (&a)[N], seq<I...>)
constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], seq<I...>)
{
return { {a[I]...} };
return {{a[I]...}};
}
} // namespace detail
template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))>
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)>
make_array(Ts&&... xs)
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)> make_array(Ts&&... xs)
{
return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))... };
return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))...};
}
constexpr std::array<int, 0> make_array()
{
return {};
}
constexpr std::array<int, 0> make_array() { return {}; }
template <class T, std::size_t N>
constexpr auto to_array(T (&a)[N])
......@@ -51,10 +50,9 @@ constexpr auto to_array(T (&a)[N])
namespace detail {
template <std::size_t Offset = 0, class Array, std::size_t... I>
constexpr auto
rearray_impl(Array a, seq<I...>)
constexpr auto rearray_impl(Array a, seq<I...>)
{
return make_array(a[I+Offset]...);
return make_array(a[I + Offset]...);
}
} // namespace detail
......
......@@ -15,7 +15,7 @@ struct swallow
}
};
template<class T>
template <class T>
auto tuple_size(const T&)
{
return typename std::tuple_size<T>::type{};
......@@ -161,39 +161,35 @@ auto index_of(T& x)
return [&](auto&& y) { return x[y]; };
}
template<class T, class... Ts>
template <class T, class... Ts>
decltype(auto) front_args(T&& x, Ts&&...)
{
return static_cast<T&&>(x);
}
template<class... Ts>
template <class... Ts>
decltype(auto) back_args(Ts&&... xs)
{
return std::get<sizeof...(Ts) - 1>(std::tuple<Ts&&...>(static_cast<Ts&&>(xs)...));
}
template<class T, class... Ts>
template <class T, class... Ts>
auto pop_front_args(T&&, Ts&&... xs)
{
return [&](auto f) {
f(static_cast<Ts&&>(xs)...);
};
return [&](auto f) { f(static_cast<Ts&&>(xs)...); };
}
template<class... Ts>
template <class... Ts>
auto pop_back_args(Ts&&... xs)
{
return [&](auto f) {
using tuple_type = std::tuple<Ts&&...>;
auto t = tuple_type(static_cast<Ts&&>(xs)...);
sequence_c<sizeof...(Ts) - 1>([&](auto... is) {
f(std::get<is>(static_cast<tuple_type&&>(t))...);
});
auto t = tuple_type(static_cast<Ts&&>(xs)...);
sequence_c<sizeof...(Ts) - 1>(
[&](auto... is) { f(std::get<is>(static_cast<tuple_type&&>(t))...); });
};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -33,7 +33,9 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return std::find(c.begin(), c.end(), x);
}
struct empty {};
struct empty
{
};
} // namespace detail
......
......@@ -259,8 +259,7 @@ void binary_broadcast_impl(
}
template <class F, class... Arguments>
void nary_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg, Arguments... args)
void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape();
......@@ -275,7 +274,7 @@ void nary_broadcast_impl(
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
std::size_t nelements = result.get_shape().elements();
std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
......@@ -289,9 +288,9 @@ void nary_broadcast_impl(
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx];
output.data()[i] = f(inputs.data()[i]..., b);
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx];
output.data()[i] = f(inputs.data()[i]..., b);
}
});
});
......@@ -363,20 +362,19 @@ auto nary(hipStream_t stream, argument result)
}
template <class... Arguments>
auto
nary(hipStream_t stream, argument result, Arguments... args)
auto nary(hipStream_t stream, argument result, Arguments... args)
{
return [=](auto f) {
auto barg = back_args(args...);
pop_back_args(args...)([&](auto&&... args2) {
auto bshape = barg.get_shape();
const bool standard = all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
const bool same_shapes =
all_of({args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
const bool standard =
all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
const bool same_shapes = all_of(
{args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
// TODO: Check result and args shape is the same
if(standard and same_shapes and bshape.broadcasted() and
not bshape.scalar())
if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = bshape.strides();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment