Commit bf3b299e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed a bug related to reduce

parent 78cb524d
......@@ -245,8 +245,7 @@ void reduce_standard_impl(hipStream_t stream,
T init,
Input read_input,
Output read_output,
std::size_t relements,
std::size_t stride)
std::size_t relements)
{
hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements();
......@@ -255,7 +254,7 @@ void reduce_standard_impl(hipStream_t stream,
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride;
const auto base_idx = out_idx * relements;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]);
});
......@@ -276,13 +275,13 @@ void reduce(hipStream_t stream,
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
assert(output_shape.lens().size() == input_shape.lens().size());
if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()),
input_shape.lens().begin()))
{
std::size_t stride = input_shape.strides().at(input_shape.strides().size() - 2);
reduce_standard_impl(stream,
result,
arg,
......@@ -290,8 +289,7 @@ void reduce(hipStream_t stream,
init,
read_input,
read_output,
input_shape.lens().back(),
stride);
input_shape.lens().back());
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment