Commit 60c6738a authored by Paul's avatar Paul
Browse files

Formatting

parent a7b934bf
...@@ -68,10 +68,10 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2) ...@@ -68,10 +68,10 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
auto* yp = as_vec4(input2.data()); auto* yp = as_vec4(input2.data());
auto* outp = as_vec4(output.data()); auto* outp = as_vec4(output.data());
const std::size_t vec_size = 4; const std::size_t vec_size = 4;
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / vec_size; const std::size_t n = output.size() / vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size; const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(nglobal, nlocal)([=](auto idx) __device__ { launch(nglobal, nlocal)([=](auto idx) __device__ {
...@@ -83,13 +83,13 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2) ...@@ -83,13 +83,13 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
__syncthreads(); __syncthreads();
for(size_t i = idx.global; i < n; i += nglobal) for(size_t i = idx.global; i < n; i += nglobal)
{ {
auto bidx = i % bdim_vec_len; auto bidx = i % bdim_vec_len;
auto b = buffer[bidx]; auto b = buffer[bidx];
vec4<type> x = xp[i]; vec4<type> x = xp[i];
vec4<type> out = outp[i]; vec4<type> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++) for(std::size_t j = 0; j < vec_size; j++)
{ {
out[j] = f(x[j], b[j]); out[j] = f(x[j], b[j]);
} }
outp[i] = out; outp[i] = out;
} }
...@@ -111,17 +111,20 @@ auto nary_standard(argument result, Arguments... args) ...@@ -111,17 +111,20 @@ auto nary_standard(argument result, Arguments... args)
gs_launch(output_shape.elements())( gs_launch(output_shape.elements())(
[=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); }); [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
#else #else
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
const std::size_t vec_size = 4; const std::size_t vec_size = 4;
auto data = pack_vec4(inputs.data()...); auto data = pack_vec4(inputs.data()...);
auto* outp = as_vec4(output.data()); auto* outp = as_vec4(output.data());
gs_launch(output_shape.elements() / vec_size)([=](auto i) { gs_launch(output_shape.elements() / vec_size)([=](auto i) {
vec4<type> out = outp[i]; vec4<type> out = outp[i];
data([&](auto... xs) { data(
for(std::size_t j = 0; j < vec_size; j++) { [&](auto... xs) {
out[j] = f(xs[j]...); for(std::size_t j = 0; j < vec_size; j++)
} {
}, i); out[j] = f(xs[j]...);
}
},
i);
outp[i] = out; outp[i] = out;
}); });
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment