Commit 5f84ce68 authored by Paul's avatar Paul
Browse files

Unroll strided loop

parent f2c9c70b
...@@ -118,6 +118,12 @@ constexpr auto sequence_c_impl(F&& f, seq<Ns...>) ...@@ -118,6 +118,12 @@ constexpr auto sequence_c_impl(F&& f, seq<Ns...>)
return f(index_constant<Ns>{}...); return f(index_constant<Ns>{}...);
} }
template <class F, index_int... Ns>
constexpr void repeat_c_impl(F f, seq<Ns...>)
{
swallow{(f(integral_constant<index_int, Ns>{}), 0)...};
}
template <index_int... N> template <index_int... N>
constexpr auto args_at(seq<N...>) constexpr auto args_at(seq<N...>)
{ {
...@@ -144,6 +150,18 @@ constexpr auto sequence(IntegerConstant ic, F&& f) ...@@ -144,6 +150,18 @@ constexpr auto sequence(IntegerConstant ic, F&& f)
return sequence_c<ic>(f); return sequence_c<ic>(f);
} }
template <std::size_t N, class F>
constexpr void repeat_c(F f)
{
detail::repeat_c_impl(f, detail::gens<N>{});
}
template <class IntegerConstant, class F>
constexpr auto repeat(IntegerConstant ic, F&& f)
{
return repeat_c<ic>(f);
}
template <class F, class G> template <class F, class G>
constexpr auto by(F f, G g) constexpr auto by(F f, G g)
{ {
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
namespace migraphx { namespace migraphx {
...@@ -63,8 +64,9 @@ struct index ...@@ -63,8 +64,9 @@ struct index
template <class F, class N, class Stride> template <class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f) static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{ {
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and if constexpr(not is_integral<N>{} and not is_integral<Stride>{})
max_stride_iterations(n, stride) == 1) {
if constexpr(max_stride_iterations(n, stride) == 1)
{ {
if constexpr(stride > n) if constexpr(stride > n)
{ {
...@@ -77,6 +79,13 @@ struct index ...@@ -77,6 +79,13 @@ struct index
} }
} }
else else
{
repeat(max_stride_iterations(n, stride), [&](auto i) {
f(start + stride*i);
});
}
}
else
{ {
for(index_int i = start; i < n; i += stride) for(index_int i = start; i < n; i += stride)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment