Commit d992494e authored by Paul's avatar Paul
Browse files

Formatting

parent 6459204c
...@@ -61,14 +61,14 @@ inline auto binary_broadcast_vec(argument result, argument arg1, argument arg2) ...@@ -61,14 +61,14 @@ inline auto binary_broadcast_vec(argument result, argument arg1, argument arg2)
{ {
return [=](auto f) { return [=](auto f) {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim = std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), std::find_if(b_shape.strides().begin(),
b_shape.strides().end(), b_shape.strides().end(),
[](auto x) { return x != 0; })); [](auto x) { return x != 0; }));
auto bdim_len = output_shape.lens()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) { visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
...@@ -90,11 +90,11 @@ inline auto binary_broadcast_vec(argument result, argument arg1, argument arg2) ...@@ -90,11 +90,11 @@ inline auto binary_broadcast_vec(argument result, argument arg1, argument arg2)
buffer[i] = yp[i]; buffer[i] = yp[i];
} }
__syncthreads(); __syncthreads();
auto * bp = as_pointer(buffer); auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < n; i += nglobal) for(size_t i = idx.global; i < n; i += nglobal)
{ {
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride; auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx]; auto b = bp[bidx];
vec4<type> x = xp[i]; vec4<type> x = xp[i];
vec4<type> out = outp[i]; vec4<type> out = outp[i];
...@@ -113,14 +113,14 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2) ...@@ -113,14 +113,14 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
{ {
return [=](auto f) { return [=](auto f) {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim = std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), std::find_if(b_shape.strides().begin(),
b_shape.strides().end(), b_shape.strides().end(),
[](auto x) { return x != 0; })); [](auto x) { return x != 0; }));
auto bdim_len = output_shape.lens()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) { visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
...@@ -225,15 +225,15 @@ inline auto nary(argument result, argument arg1, argument arg2) ...@@ -225,15 +225,15 @@ inline auto nary(argument result, argument arg1, argument arg2)
{ {
auto not_zero = [](auto x) { return x != 0; }; auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg2.get_shape().strides(); const auto& strides = arg2.get_shape().strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero); auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it); auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx]; auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx]; auto b_stride = result.get_shape().strides()[b_idx];
assert(arg2.get_shape().lens()[b_idx] == b_len); assert(arg2.get_shape().lens()[b_idx] == b_len);
if(b_len <= 2048 and if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
std::none_of(std::next(b_it), strides.end(), not_zero))
{ {
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and (arg1.get_shape().elements() % 4 == 0); const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4) if(divisible_by_4)
binary_broadcast_vec(result, arg1, arg2)(f); binary_broadcast_vec(result, arg1, arg2)(f);
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment