Commit bf3b299e authored by Shucai Xiao
Browse files

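Fixed a bug in reduce: the standard-shape path now computes each output's input base index as out_idx * relements instead of using a separately passed stride.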

parent 78cb524d
...@@ -245,8 +245,7 @@ void reduce_standard_impl(hipStream_t stream, ...@@ -245,8 +245,7 @@ void reduce_standard_impl(hipStream_t stream,
T init, T init,
Input read_input, Input read_input,
Output read_output, Output read_output,
std::size_t relements, std::size_t relements)
std::size_t stride)
{ {
hip_visit_all(result, arg)([&](auto output, auto input) { hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements(); auto nelements = result.get_shape().elements();
...@@ -255,7 +254,7 @@ void reduce_standard_impl(hipStream_t stream, ...@@ -255,7 +254,7 @@ void reduce_standard_impl(hipStream_t stream,
const std::size_t block_size = compute_block_size(relements, max_block_size); const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size; const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride; const auto base_idx = out_idx * relements;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ { auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]); return read_input(input.data()[base_idx + j]);
}); });
...@@ -276,13 +275,13 @@ void reduce(hipStream_t stream, ...@@ -276,13 +275,13 @@ void reduce(hipStream_t stream,
{ {
auto&& output_shape = result.get_shape(); auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape(); auto&& input_shape = arg.get_shape();
assert(output_shape.lens().size() == input_shape.lens().size());
if(input_shape.standard() and output_shape.standard() and if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(), std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()), std::prev(output_shape.lens().end()),
input_shape.lens().begin())) input_shape.lens().begin()))
{ {
std::size_t stride = input_shape.strides().at(input_shape.strides().size() - 2);
reduce_standard_impl(stream, reduce_standard_impl(stream,
result, result,
arg, arg,
...@@ -290,8 +289,7 @@ void reduce(hipStream_t stream, ...@@ -290,8 +289,7 @@ void reduce(hipStream_t stream,
init, init,
read_input, read_input,
read_output, read_output,
input_shape.lens().back(), input_shape.lens().back());
stride);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment