Commit 6bc2d0e3 authored by Paul
Browse files

Formatting

parent 3809fcb4
...@@ -15,17 +15,11 @@ struct index ...@@ -15,17 +15,11 @@ struct index
std::size_t local; std::size_t local;
std::size_t group; std::size_t group;
__device__ std::size_t nglobal() const __device__ std::size_t nglobal() const { return blockDim.x * gridDim.x; }
{
return blockDim.x * gridDim.x;
}
__device__ std::size_t nlocal() const __device__ std::size_t nlocal() const { return blockDim.x; }
{
return blockDim.x;
}
template<class F> template <class F>
__device__ void global_stride(std::size_t n, F f) const __device__ void global_stride(std::size_t n, F f) const
{ {
const auto stride = nglobal(); const auto stride = nglobal();
...@@ -35,7 +29,7 @@ struct index ...@@ -35,7 +29,7 @@ struct index
} }
} }
template<class F> template <class F>
__device__ void local_stride(std::size_t n, F f) const __device__ void local_stride(std::size_t n, F f) const
{ {
const auto stride = nlocal(); const auto stride = nlocal();
...@@ -83,11 +77,8 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102 ...@@ -83,11 +77,8 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102
std::size_t nglobal = std::min<std::size_t>(256, groups) * local; std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) { launch(stream, nglobal, local)(
idx.global_stride(n, [&](auto i) { [=](auto idx) { idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); });
gs_invoke(f, i, idx);
});
});
}; };
} }
......
...@@ -22,9 +22,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) ...@@ -22,9 +22,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
using type = decltype(f(idx.local)); using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N]; MIGRAPHX_DEVICE_SHARED type buffer[N];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
x = op(x, f(i));
});
buffer[idx.local] = x; buffer[idx.local] = x;
__syncthreads(); __syncthreads();
...@@ -42,7 +40,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) ...@@ -42,7 +40,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size) constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{ {
size_t block_size = 1; size_t block_size = 1;
while(block_size < max_block_size and block_size < n) while(block_size < max_block_size and block_size < n)
block_size *= 2; block_size *= 2;
return block_size; return block_size;
...@@ -69,7 +67,7 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg) ...@@ -69,7 +67,7 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
auto relements = reduce_slice.elements(); auto relements = reduce_slice.elements();
const std::size_t max_block_size = 1024; const std::size_t max_block_size = 1024;
const std::size_t block_size = compute_block_size(relements, max_block_size); const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
auto base_idx = output.get_shape().multi(i / block_size); auto base_idx = output.get_shape().multi(i / block_size);
auto offset = input.get_shape().index(base_idx); auto offset = input.get_shape().index(base_idx);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment