Commit bb666690 authored by Paul's avatar Paul
Browse files

Fix clang tidy complexity issue

parent 26a78750
...@@ -51,15 +51,9 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args) ...@@ -51,15 +51,9 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
}); });
} }
template <class... Arguments> template<class F>
auto nary_nonstandard(argument result, Arguments... args) void binary_broadcast_vec_impl(F f, const argument& result, const argument& arg1, const argument& arg2)
{
return [=](auto f) { return nary_nonstandard_impl(f, result, args...); };
}
inline auto binary_broadcast_vec(const argument& result, const argument& arg1, const argument& arg2)
{ {
return [=](auto f) {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim = std::distance(b_shape.strides().begin(),
...@@ -106,12 +100,11 @@ inline auto binary_broadcast_vec(const argument& result, const argument& arg1, c ...@@ -106,12 +100,11 @@ inline auto binary_broadcast_vec(const argument& result, const argument& arg1, c
} }
}); });
}); });
};
} }
inline auto binary_broadcast(const argument& result, const argument& arg1, const argument& arg2) template<class F>
void binary_broadcast_impl(F f, const argument& result, const argument& arg1, const argument& arg2)
{ {
return [=](auto f) {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim = std::distance(b_shape.strides().begin(),
...@@ -150,13 +143,11 @@ inline auto binary_broadcast(const argument& result, const argument& arg1, const ...@@ -150,13 +143,11 @@ inline auto binary_broadcast(const argument& result, const argument& arg1, const
} }
}); });
}); });
};
} }
template <class... Arguments> template <class F, class... Arguments>
auto nary_standard_vec(argument result, Arguments... args) void nary_standard_vec_impl(F f, argument result, Arguments... args)
{ {
return [=](auto f) {
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
...@@ -177,13 +168,11 @@ auto nary_standard_vec(argument result, Arguments... args) ...@@ -177,13 +168,11 @@ auto nary_standard_vec(argument result, Arguments... args)
outp[i] = out; outp[i] = out;
}); });
}); });
};
} }
template <class... Arguments> template <class F, class... Arguments>
auto nary_standard(argument result, Arguments... args) void nary_standard_impl(F f, argument result, Arguments... args)
{ {
return [=](auto f) {
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
...@@ -192,29 +181,44 @@ auto nary_standard(argument result, Arguments... args) ...@@ -192,29 +181,44 @@ auto nary_standard(argument result, Arguments... args)
gs_launch(output_shape.elements())( gs_launch(output_shape.elements())(
[=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); }); [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
}); });
};
} }
template <class... Arguments> template <class F, class... Arguments>
auto nary_impl(argument result, Arguments... args) void nary_impl(F f, argument result, Arguments... args)
{ {
return [=](auto f) {
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }); bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); }); bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes = bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); }); all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes)) if(standard or (packed and same_shapes))
nary_standard(result, args...)(f); nary_standard_impl(f, result, args...);
else else
nary_nonstandard(result, args...)(f); nary_nonstandard_impl(f, result, args...);
}
template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args)
{
return [=](auto f) {
nary_nonstandard_impl(f, result, args...);
};
}
template <class... Arguments>
auto nary_standard(argument result, Arguments... args)
{
return [=](auto f) {
nary_standard_impl(f, result, args...);
}; };
} }
template <class... Arguments> template <class... Arguments>
auto nary(argument result, Arguments... args) auto nary(argument result, Arguments... args)
{ {
return nary_impl(result, args...); return [=](auto f) {
nary_impl(f, result, args...);
};
} }
inline auto nary(const argument& result, const argument& arg1, const argument& arg2) inline auto nary(const argument& result, const argument& arg1, const argument& arg2)
...@@ -235,13 +239,13 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a ...@@ -235,13 +239,13 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0); (arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4) if(divisible_by_4)
binary_broadcast_vec(result, arg1, arg2)(f); binary_broadcast_vec_impl(f, result, arg1, arg2);
else else
binary_broadcast(result, arg1, arg2)(f); binary_broadcast_impl(f, result, arg1, arg2);
return; return;
} }
} }
nary_impl(result, arg1, arg2)(f); nary_impl(f, result, arg1, arg2);
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment