"test/verify/batch_quant_dot_4.cpp" did not exist on "8d21fdc9dd58e62192d9408132585eea94bbf79b"
Unverified commit 3ab91a79, authored by mvermeulen, committed by GitHub
Browse files

Merge pull request #340 from ROCmSoftwarePlatform/reduce-mean-fix

Fix incorrect stride calculation in reduce_mean
parents 083d7a99 8be676c9
......@@ -245,8 +245,7 @@ void reduce_standard_impl(hipStream_t stream,
T init,
Input read_input,
Output read_output,
std::size_t relements,
std::size_t stride)
std::size_t relements)
{
hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements();
......@@ -255,7 +254,7 @@ void reduce_standard_impl(hipStream_t stream,
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride;
const auto base_idx = out_idx * relements;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]);
});
......@@ -276,25 +275,15 @@ void reduce(hipStream_t stream,
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
assert(output_shape.lens().size() == input_shape.lens().size());
if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()),
input_shape.lens().begin()))
{
std::size_t stride = std::accumulate(input_shape.strides().begin(),
input_shape.strides().end(),
1,
std::multiplies<size_t>());
reduce_standard_impl(stream,
result,
arg,
op,
init,
read_input,
read_output,
input_shape.lens().back(),
stride);
reduce_standard_impl(
stream, result, arg, op, init, read_input, read_output, input_shape.lens().back());
}
else
{
......
......@@ -3792,6 +3792,18 @@ struct test_reduce_mean : verify_program<test_reduce_mean>
};
};
// Verification case: builds a program that feeds a float parameter "x" with
// lens {1, 128, 768} into migraphx::op::reduce_mean constructed with {2}
// (presumably the list of axes to reduce, i.e. the last dimension -- confirm
// against the reduce_mean operator definition).
struct test_reduce_mean2 : verify_program<test_reduce_mean2>
{
    // Constructs and returns the program under test; the reduce_mean
    // instruction is the last instruction added to it.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        migraphx::shape input_shape{migraphx::shape::float_type, {1, 128, 768}};
        auto input = prog.add_parameter("x", input_shape);
        prog.add_instruction(migraphx::op::reduce_mean{{2}}, input);
        return prog;
    }
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{
migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment