Commit ec5a90ae authored by Shucai Xiao
Browse files

fix two small bugs.

parent 25bac567
......@@ -80,7 +80,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
const std::size_t index = 2 * s * idx.local;
if(index < idx.nlocal())
{
buffer[index] = op(buffer[index], buffer[index + s]);
buffer[index + s] = op(buffer[index], buffer[index + s]);
}
__syncthreads();
}
......@@ -185,7 +185,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
type y = 0;
for(std::size_t i = 0; i < idx.nlocal() / 64; i++)
{
y += buffer[i];
y = op(y, buffer[i]);
}
return y;
}
......@@ -225,7 +225,7 @@ void reduce(hipStream_t stream,
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements();
const std::size_t max_block_size = 1024;
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
......
......@@ -3450,7 +3450,7 @@ struct test_reduce_sum : verify_program<test_reduce_sum>
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 8, 8}};
migraphx::shape s{migraphx::shape::float_type, {3, 1026, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment