Commit ec5a90ae authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix two small bugs.

parent 25bac567
...@@ -80,7 +80,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) ...@@ -80,7 +80,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
const std::size_t index = 2 * s * idx.local; const std::size_t index = 2 * s * idx.local;
if(index < idx.nlocal()) if(index < idx.nlocal())
{ {
buffer[index] = op(buffer[index], buffer[index + s]); buffer[index + s] = op(buffer[index], buffer[index + s]);
} }
__syncthreads(); __syncthreads();
} }
...@@ -185,7 +185,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) ...@@ -185,7 +185,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
type y = 0; type y = 0;
for(std::size_t i = 0; i < idx.nlocal() / 64; i++) for(std::size_t i = 0; i < idx.nlocal() / 64; i++)
{ {
y += buffer[i]; y = op(y, buffer[i]);
} }
return y; return y;
} }
...@@ -225,7 +225,7 @@ void reduce(hipStream_t stream, ...@@ -225,7 +225,7 @@ void reduce(hipStream_t stream,
auto nelements = result.get_shape().elements(); auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements(); auto relements = reduce_slice.elements();
const std::size_t max_block_size = 1024; const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size); const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size; const auto out_idx = i / block_size;
......
...@@ -3450,7 +3450,7 @@ struct test_reduce_sum : verify_program<test_reduce_sum> ...@@ -3450,7 +3450,7 @@ struct test_reduce_sum : verify_program<test_reduce_sum>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 8, 8}}; migraphx::shape s{migraphx::shape::float_type, {3, 1026, 4, 3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x); p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment