Commit 23c97fa9 authored by Paul's avatar Paul
Browse files

Improve loops

parent 7ddeb944
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx { namespace migraphx {
...@@ -53,26 +54,47 @@ struct index ...@@ -53,26 +54,47 @@ struct index
return blockDim.x; // NOLINT return blockDim.x; // NOLINT
} }
#endif #endif
template <class N, class Stride>
static constexpr auto max_stride_iterations(N n, Stride stride)
{
return (n - _c<1>) / stride + _c<1>;
}
template <class F> template <class F, class N, class Stride>
__device__ void global_stride(index_int n, F f) const static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{ {
const auto stride = nglobal(); if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and max_stride_iterations(n, stride) == 1)
for(index_int i = global; i < n; i += stride)
{ {
f(i); if constexpr(stride > n)
{
if (start < n)
f(start);
} }
else
{
f(start);
} }
}
template <class F> else
__device__ void local_stride(index_int n, F f) const
{ {
const auto stride = nlocal(); for(index_int i = start; i < n; i += stride)
for(index_int i = local; i < n; i += stride)
{ {
f(i); f(i);
} }
} }
}
template <class F, class N>
__device__ void global_stride(N n, F f) const
{
for_stride(global, n, nglobal(), f);
}
template <class F, class N>
__device__ void local_stride(N n, F f) const
{
for_stride(local, n, nlocal(), f);
}
}; };
inline __device__ index make_index() inline __device__ index make_index()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment