Commit 5f84ce68 authored by Paul's avatar Paul
Browse files

Unroll strided loop

parent f2c9c70b
......@@ -118,6 +118,12 @@ constexpr auto sequence_c_impl(F&& f, seq<Ns...>)
return f(index_constant<Ns>{}...);
}
template <class F, index_int... Ns>
constexpr void repeat_c_impl(F f, seq<Ns...>)
{
swallow{(f(integral_constant<index_int, Ns>{}), 0)...};
}
template <index_int... N>
constexpr auto args_at(seq<N...>)
{
......@@ -144,6 +150,18 @@ constexpr auto sequence(IntegerConstant ic, F&& f)
return sequence_c<ic>(f);
}
template <std::size_t N, class F>
constexpr void repeat_c(F f)
{
detail::repeat_c_impl(f, detail::gens<N>{});
}
template <class IntegerConstant, class F>
constexpr auto repeat(IntegerConstant ic, F&& f)
{
return repeat_c<ic>(f);
}
template <class F, class G>
constexpr auto by(F f, G g)
{
......
......@@ -27,6 +27,7 @@
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
......@@ -63,17 +64,25 @@ struct index
template <class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
max_stride_iterations(n, stride) == 1)
if constexpr(not is_integral<N>{} and not is_integral<Stride>{})
{
if constexpr(stride > n)
if constexpr(max_stride_iterations(n, stride) == 1)
{
if(start < n)
if constexpr(stride > n)
{
if(start < n)
f(start);
}
else
{
f(start);
}
}
else
{
f(start);
repeat(max_stride_iterations(n, stride), [&](auto i) {
f(start + stride*i);
});
}
}
else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment