Commit 16864eef authored by Paul's avatar Paul
Browse files

Formatting

parent 928cb435
...@@ -119,7 +119,8 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg ...@@ -119,7 +119,8 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
} }
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args) void nary_double_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = barg1.get_shape(); const auto& b_shape = barg1.get_shape();
...@@ -135,30 +136,31 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume ...@@ -135,30 +136,31 @@ void nary_double_broadcast_impl(hipStream_t stream, F f, argument result, argume
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
std::size_t nelements = result.get_shape().elements(); std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)([&](auto output, auto binput1, auto binput2, auto... inputs) { hip_visit_all(result, barg1, barg2, args...)(
using type = typename decltype(output)::value_type; [&](auto output, auto binput1, auto binput2, auto... inputs) {
launch(stream, nglobal, nlocal)([=](auto idx) __device__ { using type = typename decltype(output)::value_type;
MIGRAPHX_DEVICE_SHARED type buffer[2048]; launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
// Load bias into LDS MIGRAPHX_DEVICE_SHARED type buffer[2048];
for(size_t i = idx.local; i < bdim_len; i += nlocal) // Load bias into LDS
{ for(size_t i = idx.local; i < bdim_len; i += nlocal)
buffer[i] = binput1.data()[i]; {
} buffer[i] = binput1.data()[i];
for(size_t i = idx.local; i < bdim_len; i += nlocal) }
{ for(size_t i = idx.local; i < bdim_len; i += nlocal)
buffer[i+bdim_len] = binput2.data()[i+bdim_len]; {
} buffer[i + bdim_len] = binput2.data()[i + bdim_len];
__syncthreads(); }
// Process the data __syncthreads();
for(size_t i = idx.global; i < nelements; i += nglobal) // Process the data
{ for(size_t i = idx.global; i < nelements; i += nglobal)
auto bidx = (i % bdim_next_stride) / bdim_stride; {
auto b1 = buffer[bidx]; auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b2 = buffer[bidx+bdim_len]; auto b1 = buffer[bidx];
output.data()[i] = f(inputs.data()[i]..., b1, b2); auto b2 = buffer[bidx + bdim_len];
} output.data()[i] = f(inputs.data()[i]..., b1, b2);
}
});
}); });
});
} }
template <class F, class... Arguments> template <class F, class... Arguments>
...@@ -219,15 +221,15 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args) ...@@ -219,15 +221,15 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
return [=](auto f) { nary_standard_impl(stream, f, result, args...); }; return [=](auto f) { nary_standard_impl(stream, f, result, args...); };
} }
template<class... Arguments> template <class... Arguments>
bool broadcastable(bool &divisible_by_4, argument result, argument barg, Arguments... args) bool broadcastable(bool& divisible_by_4, argument result, argument barg, Arguments... args)
{ {
divisible_by_4 = false; divisible_by_4 = false;
auto bshape = barg.get_shape(); auto bshape = barg.get_shape();
const bool standard = const bool standard =
all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }); all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
const bool same_shapes = all_of( const bool same_shapes =
{args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); }); all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
// TODO: Check result and args shape is the same // TODO: Check result and args shape is the same
if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar()) if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
{ {
...@@ -241,16 +243,15 @@ bool broadcastable(bool &divisible_by_4, argument result, argument barg, Argumen ...@@ -241,16 +243,15 @@ bool broadcastable(bool &divisible_by_4, argument result, argument barg, Argumen
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero)) if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{ {
divisible_by_4 = divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(b_len % 4 == 0) and (b_stride % 4 == 0) and (front_args(args...).get_shape().elements() % 4 == 0);
(front_args(args...).get_shape().elements() % 4 == 0);
return true; return true;
} }
} }
return false; return false;
} }
inline bool broadcastable(bool &divisible_by_4, argument, argument) inline bool broadcastable(bool& divisible_by_4, argument, argument)
{ {
divisible_by_4 = false; divisible_by_4 = false;
return false; return false;
...@@ -265,9 +266,7 @@ inline auto nary(hipStream_t stream, argument result) ...@@ -265,9 +266,7 @@ inline auto nary(hipStream_t stream, argument result)
// Unary // Unary
inline auto nary(hipStream_t stream, argument result, argument arg) inline auto nary(hipStream_t stream, argument result, argument arg)
{ {
return [=](auto f) { return [=](auto f) { nary_impl(stream, f, result, arg); };
nary_impl(stream, f, result, arg);
};
} }
// Binary // Binary
...@@ -275,7 +274,7 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument bar ...@@ -275,7 +274,7 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument bar
{ {
return [=](auto f) { return [=](auto f) {
bool divisible_by_4 = false; bool divisible_by_4 = false;
if (broadcastable(divisible_by_4, result, barg, arg)) if(broadcastable(divisible_by_4, result, barg, arg))
{ {
if(divisible_by_4) if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg, arg); nary_broadcast_vec_impl(stream, f, result, barg, arg);
...@@ -293,10 +292,10 @@ template <class... Arguments> ...@@ -293,10 +292,10 @@ template <class... Arguments>
auto nary(hipStream_t stream, argument result, Arguments... args) auto nary(hipStream_t stream, argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) {
auto barg1 = back_args(args...); auto barg1 = back_args(args...);
bool fallback = pop_back_args(args...)([&](auto&&... args2) { bool fallback = pop_back_args(args...)([&](auto&&... args2) {
bool divisible_by_4 = false; bool divisible_by_4 = false;
if (broadcastable(divisible_by_4, result, barg1, args2...)) if(broadcastable(divisible_by_4, result, barg1, args2...))
{ {
if(divisible_by_4) if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg1, args2...); nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment