Commit 07857fb4 authored by Paul's avatar Paul
Browse files

Merge

parents 1689d2d8 62e8ec20
...@@ -118,6 +118,12 @@ constexpr auto sequence_c_impl(F&& f, seq<Ns...>) ...@@ -118,6 +118,12 @@ constexpr auto sequence_c_impl(F&& f, seq<Ns...>)
return f(index_constant<Ns>{}...); return f(index_constant<Ns>{}...);
} }
template <class F, index_int... Ns>
constexpr void repeat_c_impl(F f, seq<Ns...>)
{
swallow{(f(integral_constant<index_int, Ns>{}), 0)...};
}
template <index_int... N> template <index_int... N>
constexpr auto args_at(seq<N...>) constexpr auto args_at(seq<N...>)
{ {
...@@ -144,6 +150,18 @@ constexpr auto sequence(IntegerConstant ic, F&& f) ...@@ -144,6 +150,18 @@ constexpr auto sequence(IntegerConstant ic, F&& f)
return sequence_c<ic>(f); return sequence_c<ic>(f);
} }
template <std::size_t N, class F>
constexpr void repeat_c(F f)
{
detail::repeat_c_impl(f, detail::gens<N>{});
}
template <class IntegerConstant, class F>
constexpr auto repeat(IntegerConstant ic, F&& f)
{
return repeat_c<ic>(f);
}
template <class F, class G> template <class F, class G>
constexpr auto by(F f, G g) constexpr auto by(F f, G g)
{ {
......
...@@ -27,6 +27,8 @@ ...@@ -27,6 +27,8 @@
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx { namespace migraphx {
...@@ -64,29 +66,61 @@ struct index ...@@ -64,29 +66,61 @@ struct index
{ {
return _c<1> + n / nlocal(); return _c<1> + n / nlocal();
} }
template <class N, class Stride>
static constexpr auto max_stride_iterations(N n, Stride stride)
{
return (n - _c<1>) / stride + _c<1>;
}
template <class F> template <class F, class N, class Stride>
__device__ void global_stride(index_int n, F f) const static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{
if constexpr(not is_integral<N>{} and not is_integral<Stride>{})
{
if constexpr(max_stride_iterations(n, stride) == 1)
{
if constexpr(stride > n)
{ {
const auto stride = nglobal(); if(start < n)
for(index_int i = global; i < n; i += stride) f(start);
}
else
{
f(start);
}
}
else
{ {
repeat(max_stride_iterations(n, stride), [&](auto k) {
auto i = start + stride * k;
if(i < n)
f(i); f(i);
});
} }
} }
else
template <class F>
__device__ void local_stride(index_int n, F f) const
{ {
const auto stride = nlocal(); for(index_int i = start; i < n; i += stride)
for(index_int i = local; i < n; i += stride)
{ {
f(i); f(i);
} }
} }
}
template <class F, class N>
__device__ void global_stride(N n, F f) const
{
for_stride(global, n, nglobal(), f);
}
template <class F, class N>
__device__ void local_stride(N n, F f) const
{
for_stride(local, n, nlocal(), f);
}
}; };
inline __device__ index make_index() inline __device__ __attribute__((const)) index make_index()
{ {
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
} }
......
...@@ -186,6 +186,7 @@ __device__ auto auto_preload(index idx) ...@@ -186,6 +186,7 @@ __device__ auto auto_preload(index idx)
{ {
return make_transform([=](auto f, auto... xs) { return make_transform([=](auto f, auto... xs) {
auto invoke = [=](auto... ys) { auto invoke = [=](auto... ys) {
if constexpr((Bs or ...))
__syncthreads(); __syncthreads();
f(ys...); f(ys...);
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment