Commit 928cb435 authored by Paul
Browse files

Refactor nary

parent 784dc2aa
......@@ -118,6 +118,49 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
});
}
// Launches a kernel that computes, for every element i of `result`,
//     result[i] = f(args[i]..., b1, b2)
// where b1 and b2 are the entries of the two broadcasted arguments `barg1` and `barg2`
// that correspond to element i along the broadcast axis.
//
// stream       - HIP stream the kernel is launched on.
// f            - callable invoked with one element of each of `args`, then b1, then b2.
// result       - output argument; its shape supplies the length and stride of the
//                broadcast axis and the total element count.
// barg1, barg2 - broadcasted arguments. The broadcast axis is the first axis of
//                `barg1` with a non-zero stride.
//                NOTE(review): `barg2` is assumed to be broadcast along the same axis
//                with the same length as `barg1` - confirm against callers.
// args         - remaining inputs, indexed with the same linear index as `result`.
template <class F, class... Arguments>
void nary_double_broadcast_impl(
    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
    const auto& output_shape = result.get_shape();
    const auto& b_shape      = barg1.get_shape();
    // Index of the first axis of barg1 whose stride is non-zero
    auto bdim =
        std::distance(b_shape.strides().begin(),
                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
                          return x != 0;
                      }));
    auto bdim_len         = output_shape.lens()[bdim];
    auto bdim_stride      = output_shape.strides()[bdim];
    auto bdim_next_stride = bdim_stride * bdim_len;
    // Both broadcast vectors are staged back to back in one 2048-element buffer
    // inside the kernel, so together they must not exceed its capacity.
    assert(2 * bdim_len <= 2048);

    const std::size_t nlocal  = 1024;
    const std::size_t nglobal = 256 * nlocal;
    std::size_t nelements     = result.get_shape().elements();
    hip_visit_all(result, barg1, barg2, args...)(
        [&](auto output, auto binput1, auto binput2, auto... inputs) {
            using type = typename decltype(output)::value_type;
            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
                MIGRAPHX_DEVICE_SHARED type buffer[2048];
                // Load bias into LDS: barg1 fills buffer[0, bdim_len)
                for(size_t i = idx.local; i < bdim_len; i += nlocal)
                {
                    buffer[i] = binput1.data()[i];
                }
                // barg2 fills buffer[bdim_len, 2 * bdim_len). Only the destination is
                // offset; the source is read from its own start, exactly like barg1.
                for(size_t i = idx.local; i < bdim_len; i += nlocal)
                {
                    buffer[i + bdim_len] = binput2.data()[i];
                }
                // All loads must be complete before any thread reads the buffer
                __syncthreads();
                // Process the data
                for(size_t i = idx.global; i < nelements; i += nglobal)
                {
                    // Position of element i along the broadcast axis
                    auto bidx        = (i % bdim_next_stride) / bdim_stride;
                    auto b1          = buffer[bidx];
                    auto b2          = buffer[bidx + bdim_len];
                    output.data()[i] = f(inputs.data()[i]..., b1, b2);
                }
            });
        });
}
template <class F, class... Arguments>
void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
......@@ -176,46 +219,90 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
return [=](auto f) { nary_standard_impl(stream, f, result, args...); };
}
template <class... Arguments>
auto nary(hipStream_t stream, argument result)
// Reports whether `barg` can be combined with `args` through the broadcast kernels.
//
// Returns true only when all of the following hold:
//   - every shape in `args` is standard and equal to the shape of `result`;
//   - the shape of `barg` is broadcasted and is not a scalar;
//   - the first axis of `barg` with a non-zero stride is the only such axis, and the
//     length of that axis in `result` is at most 2048.
//
// `divisible_by_4` is an output: it is reset to false on entry and, when the function
// returns true, set to whether the broadcast axis length, its stride in `result`, and
// the element count of the first entry of `args` are all multiples of 4.
template <class... Arguments>
bool broadcastable(bool& divisible_by_4, argument result, argument barg, Arguments... args)
{
    divisible_by_4 = false;
    auto bshape    = barg.get_shape();

    const auto is_standard    = [](const shape& s) { return s.standard(); };
    const auto matches_result = [&](const shape& s) { return s == result.get_shape(); };
    const bool standard       = all_of({args.get_shape()...}, is_standard);
    const bool same_shapes    = all_of({args.get_shape()...}, matches_result);
    // TODO: Check result and args shape is the same
    if(not(standard and same_shapes))
        return false;
    if(not bshape.broadcasted() or bshape.scalar())
        return false;

    const auto is_nonzero = [](auto x) { return x != 0; };
    const auto& strides   = bshape.strides();
    // Locate the first axis of barg that actually varies
    auto first_nonzero     = std::find_if(strides.begin(), strides.end(), is_nonzero);
    auto axis              = std::distance(strides.begin(), first_nonzero);
    const auto& out_shape  = result.get_shape();
    auto axis_len          = out_shape.lens()[axis];
    auto axis_stride       = out_shape.strides()[axis];
    assert(bshape.lens()[axis] == axis_len);

    // Reject an axis that is too long, or any further axis with a non-zero stride
    if(axis_len > 2048 or std::any_of(std::next(first_nonzero), strides.end(), is_nonzero))
        return false;

    divisible_by_4 = (axis_len % 4 == 0) and (axis_stride % 4 == 0) and
                     (front_args(args...).get_shape().elements() % 4 == 0);
    return true;
}
// Overload selected when no inputs are passed besides the result and the candidate
// broadcast argument: always reports "not broadcastable" and clears `divisible_by_4`.
inline bool broadcastable(bool& divisible_by_4, argument /*result*/, argument /*barg*/)
{
    constexpr bool can_broadcast = false;
    divisible_by_4               = can_broadcast;
    return can_broadcast;
}
// Nullary
// Returns a callable that takes `f` and forwards it, with `stream` and `result` and no
// inputs, to nary_standard_impl.
inline auto nary(hipStream_t stream, argument result)
{
    return [stream, result](auto f) { nary_standard_impl(stream, f, result); };
}
// Unary
// Returns a callable that takes `f` and forwards it, with `stream`, `result` and the
// single input `arg`, to nary_impl.
inline auto nary(hipStream_t stream, argument result, argument arg)
{
    return [stream, result, arg](auto f) { nary_impl(stream, f, result, arg); };
}
// Binary
// Returns a callable that takes `f` and dispatches on whether `barg` is broadcastable
// against `result` and `arg`:
//   - not broadcastable           -> nary_impl(stream, f, result, arg, barg)
//   - broadcastable, 4-divisible  -> nary_broadcast_vec_impl(stream, f, result, barg, arg)
//   - broadcastable otherwise     -> nary_broadcast_impl(stream, f, result, barg, arg)
inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
{
    return [stream, result, arg, barg](auto f) {
        bool use_vec = false;
        // General path when the second operand does not qualify for broadcasting
        if(not broadcastable(use_vec, result, barg, arg))
        {
            nary_impl(stream, f, result, arg, barg);
            return;
        }
        if(use_vec)
        {
            nary_broadcast_vec_impl(stream, f, result, barg, arg);
        }
        else
        {
            nary_broadcast_impl(stream, f, result, barg, arg);
        }
    };
}
template <class... Arguments>
auto nary(hipStream_t stream, argument result, Arguments... args)
{
return [=](auto f) {
auto barg = back_args(args...);
auto barg1 = back_args(args...);
bool fallback = pop_back_args(args...)([&](auto&&... args2) {
auto bshape = barg.get_shape();
const bool standard =
all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
const bool same_shapes = all_of(
{args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
// TODO: Check result and args shape is the same
if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
bool divisible_by_4 = false;
if (broadcastable(divisible_by_4, result, barg1, args2...))
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = bshape.strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(bshape.lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
const bool divisible_by_4 =
(b_len % 4 == 0) and (b_stride % 4 == 0) and
(front_args(args...).get_shape().elements() % 4 == 0);
if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg, args2...);
else
nary_broadcast_impl(stream, f, result, barg, args2...);
return false;
}
if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
else
nary_broadcast_impl(stream, f, result, barg1, args2...);
return false;
}
return true;
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment