Commit 53793762 authored by Paul's avatar Paul
Browse files

Formatting

parent 55422f0e
......@@ -241,7 +241,8 @@ void binary_broadcast_impl(
}
template <class F, class... Arguments>
void nary_broadcast_vec_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
void nary_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape();
......@@ -258,7 +259,8 @@ void nary_broadcast_vec_impl(hipStream_t stream, F f, argument result, argument
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
hip_vec_visit_all<vec_size>(result, barg, args...)(
[&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type;
const std::size_t nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
......@@ -417,7 +419,8 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
const bool divisible_by_4 =
(b_len % 4 == 0) and (b_stride % 4 == 0) and
(front_args(args...).get_shape().elements() % 4 == 0);
if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg, args2...);
......
......@@ -58,7 +58,6 @@ struct device_type<half>
using type = gpu_half;
};
template <class T>
struct host_type
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment