Commit 6bc2d0e3 authored by Paul's avatar Paul
Browse files

Formatting

parent 3809fcb4
...@@ -15,17 +15,11 @@ struct index ...@@ -15,17 +15,11 @@ struct index
std::size_t local; std::size_t local;
std::size_t group; std::size_t group;
__device__ std::size_t nglobal() const __device__ std::size_t nglobal() const { return blockDim.x * gridDim.x; }
{
return blockDim.x * gridDim.x;
}
__device__ std::size_t nlocal() const __device__ std::size_t nlocal() const { return blockDim.x; }
{
return blockDim.x;
}
template<class F> template <class F>
__device__ void global_stride(std::size_t n, F f) const __device__ void global_stride(std::size_t n, F f) const
{ {
const auto stride = nglobal(); const auto stride = nglobal();
...@@ -35,7 +29,7 @@ struct index ...@@ -35,7 +29,7 @@ struct index
} }
} }
template<class F> template <class F>
__device__ void local_stride(std::size_t n, F f) const __device__ void local_stride(std::size_t n, F f) const
{ {
const auto stride = nlocal(); const auto stride = nlocal();
...@@ -83,11 +77,8 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102 ...@@ -83,11 +77,8 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102
std::size_t nglobal = std::min<std::size_t>(256, groups) * local; std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) { launch(stream, nglobal, local)(
idx.global_stride(n, [&](auto i) { [=](auto idx) { idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); });
gs_invoke(f, i, idx);
});
});
}; };
} }
......
...@@ -22,9 +22,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) ...@@ -22,9 +22,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
using type = decltype(f(idx.local)); using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N]; MIGRAPHX_DEVICE_SHARED type buffer[N];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
x = op(x, f(i));
});
buffer[idx.local] = x; buffer[idx.local] = x;
__syncthreads(); __syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment