Commit bb666690 authored by Paul
Browse files

Fix clang tidy complexity issue

parent 26a78750
...@@ -51,15 +51,9 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args) ...@@ -51,15 +51,9 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
}); });
} }
template <class... Arguments> template<class F>
auto nary_nonstandard(argument result, Arguments... args) void binary_broadcast_vec_impl(F f, const argument& result, const argument& arg1, const argument& arg2)
{
return [=](auto f) { return nary_nonstandard_impl(f, result, args...); };
}
inline auto binary_broadcast_vec(const argument& result, const argument& arg1, const argument& arg2)
{ {
return [=](auto f) {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim = std::distance(b_shape.strides().begin(),
...@@ -106,57 +100,54 @@ inline auto binary_broadcast_vec(const argument& result, const argument& arg1, c ...@@ -106,57 +100,54 @@ inline auto binary_broadcast_vec(const argument& result, const argument& arg1, c
} }
}); });
}); });
};
} }
inline auto binary_broadcast(const argument& result, const argument& arg1, const argument& arg2) template<class F>
void binary_broadcast_impl(F f, const argument& result, const argument& arg1, const argument& arg2)
{ {
return [=](auto f) { const auto& output_shape = result.get_shape();
const auto& output_shape = result.get_shape(); const auto& b_shape = arg2.get_shape();
const auto& b_shape = arg2.get_shape(); auto bdim = std::distance(b_shape.strides().begin(),
auto bdim = std::distance(b_shape.strides().begin(), std::find_if(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(),
b_shape.strides().end(), [](auto x) { return x != 0; }));
[](auto x) { return x != 0; })); auto bdim_len = output_shape.lens()[bdim];
auto bdim_len = output_shape.lens()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_next_stride = bdim_stride * bdim_len;
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) { visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data(); auto* xp = input1.data();
auto* yp = input2.data(); auto* yp = input2.data();
auto* outp = output.data(); auto* outp = output.data();
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size(); const std::size_t n = output.size();
launch(nglobal, nlocal)([=](auto idx) __device__ { launch(nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPH_DEVICE_SHARED type buffer[2048]; MIGRAPH_DEVICE_SHARED type buffer[2048];
// Load bias into LDS // Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal) for(size_t i = idx.local; i < bdim_len; i += nlocal)
{ {
buffer[i] = yp[i]; buffer[i] = yp[i];
} }
__syncthreads(); __syncthreads();
// Process the data // Process the data
for(size_t i = idx.global; i < n; i += nglobal) for(size_t i = idx.global; i < n; i += nglobal)
{ {
auto bidx = (i % bdim_next_stride) / bdim_stride; auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx]; auto b = buffer[bidx];
type x = xp[i]; type x = xp[i];
outp[i] = f(x, b); outp[i] = f(x, b);
} }
});
}); });
}; });
} }
template <class... Arguments> template <class F, class... Arguments>
auto nary_standard_vec(argument result, Arguments... args) void nary_standard_vec_impl(F f, argument result, Arguments... args)
{ {
return [=](auto f) {
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
...@@ -177,13 +168,11 @@ auto nary_standard_vec(argument result, Arguments... args) ...@@ -177,13 +168,11 @@ auto nary_standard_vec(argument result, Arguments... args)
outp[i] = out; outp[i] = out;
}); });
}); });
};
} }
template <class... Arguments> template <class F, class... Arguments>
auto nary_standard(argument result, Arguments... args) void nary_standard_impl(F f, argument result, Arguments... args)
{ {
return [=](auto f) {
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
...@@ -192,29 +181,44 @@ auto nary_standard(argument result, Arguments... args) ...@@ -192,29 +181,44 @@ auto nary_standard(argument result, Arguments... args)
gs_launch(output_shape.elements())( gs_launch(output_shape.elements())(
[=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); }); [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
}); });
};
} }
template <class... Arguments> template <class F, class... Arguments>
auto nary_impl(argument result, Arguments... args) void nary_impl(F f, argument result, Arguments... args)
{ {
return [=](auto f) {
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }); bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); }); bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes = bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); }); all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes)) if(standard or (packed and same_shapes))
nary_standard(result, args...)(f); nary_standard_impl(f, result, args...);
else else
nary_nonstandard(result, args...)(f); nary_nonstandard_impl(f, result, args...);
}
// Deferred-call wrapper around nary_nonstandard_impl.
// Copies `result` and `args` into the returned closure (by-value capture) and
// hands back a generic callable; invoking that callable with a function
// object `f` runs nary_nonstandard_impl(f, result, args...).
// The closure itself returns nothing.
template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args)
{
// [=] stores copies, so the closure does not refer back to this function's parameters.
return [=](auto f) {
nary_nonstandard_impl(f, result, args...);
};
}
template <class... Arguments>
auto nary_standard(argument result, Arguments... args)
{
return [=](auto f) {
nary_standard_impl(f, result, args...);
}; };
} }
template <class... Arguments> template <class... Arguments>
auto nary(argument result, Arguments... args) auto nary(argument result, Arguments... args)
{ {
return nary_impl(result, args...); return [=](auto f) {
nary_impl(f, result, args...);
};
} }
inline auto nary(const argument& result, const argument& arg1, const argument& arg2) inline auto nary(const argument& result, const argument& arg1, const argument& arg2)
...@@ -235,13 +239,13 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a ...@@ -235,13 +239,13 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0); (arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4) if(divisible_by_4)
binary_broadcast_vec(result, arg1, arg2)(f); binary_broadcast_vec_impl(f, result, arg1, arg2);
else else
binary_broadcast(result, arg1, arg2)(f); binary_broadcast_impl(f, result, arg1, arg2);
return; return;
} }
} }
nary_impl(result, arg1, arg2)(f); nary_impl(f, result, arg1, arg2);
}; };
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment