Commit 3b04798c authored by Paul
Browse files

Formatting

parent fbcb4570
...@@ -131,7 +131,7 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2) ...@@ -131,7 +131,7 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
{ {
buffer[bdim_vec_len][i] = yp[bdim_vec_len][i]; buffer[bdim_vec_len][i] = yp[bdim_vec_len][i];
} }
for(size_t i = idx.local; i < (vec_size-bdim_vec_rem); i += nlocal) for(size_t i = idx.local; i < (vec_size - bdim_vec_rem); i += nlocal)
{ {
buffer[bdim_vec_len][i] = yp[0][i]; buffer[bdim_vec_len][i] = yp[0][i];
} }
...@@ -226,13 +226,14 @@ inline auto nary(argument result, argument arg1, argument arg2) ...@@ -226,13 +226,14 @@ inline auto nary(argument result, argument arg1, argument arg2)
{ {
auto not_zero = [](auto x) { return x != 0; }; auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg2.get_shape().strides(); const auto& strides = arg2.get_shape().strides();
auto stride_it = std::find_if(strides.begin(), auto stride_it = std::find_if(strides.begin(), strides.end(), not_zero);
strides.end(), not_zero);
auto stride_idx = std::distance(strides.begin(), stride_it); auto stride_idx = std::distance(strides.begin(), stride_it);
auto stride_len = arg2.get_shape().lens()[stride_idx]; auto stride_len = arg2.get_shape().lens()[stride_idx];
// TODO: Don't require divisibility by 4 // TODO: Don't require divisibility by 4
bool divisible_by_4 = (stride_len % 4 == 0) and (arg1.get_shape().elements() % 4 == 0); bool divisible_by_4 = (stride_len % 4 == 0) and (arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4 and stride_len <= 2048 and std::none_of(std::next(stride_it), strides.end(), not_zero)) { if(divisible_by_4 and stride_len <= 2048 and
std::none_of(std::next(stride_it), strides.end(), not_zero))
{
binary_broadcast(result, arg1, arg2)(f); binary_broadcast(result, arg1, arg2)(f);
return; return;
} }
......
...@@ -77,7 +77,7 @@ struct auto_print ...@@ -77,7 +77,7 @@ struct auto_print
}; };
std::array<std::function<void()>, 2> auto_print::handlers = {}; std::array<std::function<void()>, 2> auto_print::handlers = {};
template<class T> template <class T>
auto get_hash(const T& x) auto get_hash(const T& x)
{ {
return std::hash<T>{}(x); return std::hash<T>{}(x);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment