Commit 16864eef authored by Paul's avatar Paul
Browse files

Formatting

parent 928cb435
......@@ -119,7 +119,8 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
}
template <class F, class... Arguments>
void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
void nary_double_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = barg1.get_shape();
......@@ -135,7 +136,8 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)([&](auto output, auto binput1, auto binput2, auto... inputs) {
hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
using type = typename decltype(output)::value_type;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048];
......@@ -146,7 +148,7 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
}
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i+bdim_len] = binput2.data()[i+bdim_len];
buffer[i + bdim_len] = binput2.data()[i + bdim_len];
}
__syncthreads();
// Process the data
......@@ -154,7 +156,7 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b1 = buffer[bidx];
auto b2 = buffer[bidx+bdim_len];
auto b2 = buffer[bidx + bdim_len];
output.data()[i] = f(inputs.data()[i]..., b1, b2);
}
});
......@@ -219,15 +221,15 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
return [=](auto f) { nary_standard_impl(stream, f, result, args...); };
}
template<class... Arguments>
bool broadcastable(bool &divisible_by_4, argument result, argument barg, Arguments... args)
template <class... Arguments>
bool broadcastable(bool& divisible_by_4, argument result, argument barg, Arguments... args)
{
divisible_by_4 = false;
auto bshape = barg.get_shape();
const bool standard =
all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
const bool same_shapes = all_of(
{args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
const bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
// TODO: Check result and args shape is the same
if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
{
......@@ -241,8 +243,7 @@ bool broadcastable(bool &divisible_by_4, argument result, argument barg, Argumen
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
divisible_by_4 =
(b_len % 4 == 0) and (b_stride % 4 == 0) and
divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(front_args(args...).get_shape().elements() % 4 == 0);
return true;
}
......@@ -250,7 +251,7 @@ bool broadcastable(bool &divisible_by_4, argument result, argument barg, Argumen
return false;
}
inline bool broadcastable(bool &divisible_by_4, argument, argument)
inline bool broadcastable(bool& divisible_by_4, argument, argument)
{
divisible_by_4 = false;
return false;
......@@ -265,9 +266,7 @@ inline auto nary(hipStream_t stream, argument result)
// Unary
inline auto nary(hipStream_t stream, argument result, argument arg)
{
return [=](auto f) {
nary_impl(stream, f, result, arg);
};
return [=](auto f) { nary_impl(stream, f, result, arg); };
}
// Binary
......@@ -275,7 +274,7 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument bar
{
return [=](auto f) {
bool divisible_by_4 = false;
if (broadcastable(divisible_by_4, result, barg, arg))
if(broadcastable(divisible_by_4, result, barg, arg))
{
if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg, arg);
......@@ -296,7 +295,7 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
auto barg1 = back_args(args...);
bool fallback = pop_back_args(args...)([&](auto&&... args2) {
bool divisible_by_4 = false;
if (broadcastable(divisible_by_4, result, barg1, args2...))
if(broadcastable(divisible_by_4, result, barg1, args2...))
{
if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment