"vscode:/vscode.git/clone" did not exist on "99ebfebad1b2eb22e710a97ab9ef5bd2b6bb443e"
Commit 53793762 authored by Paul's avatar Paul
Browse files

Formatting

parent 55422f0e
...@@ -241,7 +241,8 @@ void binary_broadcast_impl( ...@@ -241,7 +241,8 @@ void binary_broadcast_impl(
} }
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_broadcast_vec_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args) void nary_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape(); const auto& b_shape = barg.get_shape();
...@@ -258,35 +259,36 @@ void nary_broadcast_vec_impl(hipStream_t stream, F f, argument result, argument ...@@ -258,35 +259,36 @@ void nary_broadcast_vec_impl(hipStream_t stream, F f, argument result, argument
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
const std::size_t bdim_vec_len = bdim_len / vec_size; const std::size_t bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)([&](auto output, auto binput, auto... inputs) { hip_vec_visit_all<vec_size>(result, barg, args...)(
using type = typename decltype(output)::value_type; [&](auto output, auto binput, auto... inputs) {
const std::size_t nelements = output.size() / vec_size; using type = typename decltype(output)::value_type;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ { const std::size_t nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
// Load bias into LDS MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal) // Load bias into LDS
{ for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
buffer[i] = binput.data()[i]; {
} buffer[i] = binput.data()[i];
__syncthreads(); }
auto* bp = as_pointer(buffer); __syncthreads();
// Process the data auto* bp = as_pointer(buffer);
for(size_t i = idx.global; i < nelements; i += nglobal) // Process the data
{ for(size_t i = idx.global; i < nelements; i += nglobal)
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride; {
auto b = bp[bidx]; auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto out = output.data()[i]; auto b = bp[bidx];
pack(inputs.data()[i]...)([&](auto... xs) __device__ { auto out = output.data()[i];
for(std::size_t j = 0; j < vec_size; j++) pack(inputs.data()[i]...)([&](auto... xs) __device__ {
{ for(std::size_t j = 0; j < vec_size; j++)
output.data()[i][j] = f(xs[j]..., b); {
} output.data()[i][j] = f(xs[j]..., b);
}); }
output.data()[i] = out; });
} output.data()[i] = out;
}
});
}); });
});
} }
template <class F, class... Arguments> template <class F, class... Arguments>
...@@ -417,8 +419,9 @@ auto nary(hipStream_t stream, argument result, Arguments... args) ...@@ -417,8 +419,9 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero)) if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{ {
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and const bool divisible_by_4 =
(front_args(args...).get_shape().elements() % 4 == 0); (b_len % 4 == 0) and (b_stride % 4 == 0) and
(front_args(args...).get_shape().elements() % 4 == 0);
if(divisible_by_4) if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg, args2...); nary_broadcast_vec_impl(stream, f, result, barg, args2...);
else else
......
...@@ -58,7 +58,6 @@ struct device_type<half> ...@@ -58,7 +58,6 @@ struct device_type<half>
using type = gpu_half; using type = gpu_half;
}; };
template <class T> template <class T>
struct host_type struct host_type
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment