Commit 506a73ec authored by Paul
Browse files

Reduce number of index calculations

parent 6bc2d0e3
......@@ -73,6 +73,22 @@ struct hip_shape
}
return result;
}
// Propagates overflow through a multi-index, walking from the last dimension
// towards the first.
//
// At each dimension the amount handed over from the previously visited
// (higher-numbered) dimension is added to the entry. If the sum goes past
// lens[dim] - 1, the entry is set to lens[dim] - 1 and the surplus is handed
// to the next dimension visited; otherwise nothing is handed on.
//
// result: the multi-index to adjust (taken by value).
// Returns the adjusted copy. Any surplus left after dimension 0 is dropped.
MIGRAPHX_DEVICE_CONSTEXPR hip_index carry(hip_index result) const
{
    std::ptrdiff_t excess = 0;
    std::ptrdiff_t dim    = result.size();
    while(dim-- > 0)
    {
        auto value = result[dim] + excess;
        // Amount by which value goes past the largest allowed entry, lens[dim] - 1.
        excess = value - std::ptrdiff_t(lens[dim]) + 1;
        if(excess <= 0)
            excess = 0; // entry fits: nothing to hand on
        else
            value -= excess; // pin the entry at lens[dim] - 1
        result[dim] = value;
    }
    return result;
}
};
template <std::size_t N>
......
......@@ -61,7 +61,7 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
else
return y;
});
shape reduce_slice{output_shape.type(), reduce_lens, input_shape.strides()};
shape reduce_slice{output_shape.type(), reduce_lens};
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements();
......@@ -69,13 +69,14 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
const std::size_t max_block_size = 1024;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
auto base_idx = output.get_shape().multi(i / block_size);
auto offset = input.get_shape().index(base_idx);
const auto out_idx = i / block_size;
auto base_idx = output.get_shape().multi(out_idx);
auto r = block_reduce<max_block_size>(idx, sum{}, 0, relements, [&](auto j) __device__ {
return input.data()[reduce_shape.index(j) + offset];
auto reduce_idx = reduce_shape.multi(j);
return input[reduce_idx + base_idx];
});
if(idx.local == 0)
output.data()[i / block_size] = r;
output.data()[out_idx] = r;
});
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment