Commit 28727db2 authored by Ted Themistokleous

Fix nonzero to pad its output with a sentinel value based on the element count

We can't change the behaviour of the nonzero op and we currently pad the output
with zeros. This unfortunately obfuscates the following cases:

1. When the only nonzero element is at the first index. The whole tensor is padded
   with zeros, so it's not obvious whether the first value is a valid index or padding.

2. When the nonzero elements vector is used for indices. A padded value of 0 in the
   resulting vector is still a valid index, so gather/gatherND and other ops
   will assume the 0 index is valid and operate accordingly.
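
To make the second case concrete, here is a minimal sketch. It is not MIGraphX code: `gather_1d` is a hypothetical helper that only mimics a one-dimensional gather, which has no way of knowing that trailing zeros were meant as padding.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 1-D gather (illustration only): picks data[i] for every
// index i and has no notion of padding.
std::vector<int> gather_1d(const std::vector<int>& data,
                           const std::vector<std::int64_t>& indices)
{
    std::vector<int> out;
    out.reserve(indices.size());
    for(auto i : indices)
        out.push_back(data.at(static_cast<std::size_t>(i)));
    return out;
}
```

For an input such as `{7, 0, 9, 0}`, a zero-padded nonzero result would be `{0, 2, 0, 0}`, and gathering with it yields `{7, 9, 7, 7}`: the two padded slots silently re-read element 0.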

To address this, the padding now uses a sentinel value equal to the number of static
elements in the desired output shape. A consumer of the nonzero output can then tell
how many entries are valid by checking whether each value lies in the valid index range.
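
The idea can be sketched in standalone C++. This is an illustration under assumptions, not the MIGraphX implementation: `nonzero_padded` and `count_valid` are hypothetical names, and the output is assumed to be a row-major [rank x n] tensor, where n is the number of input elements.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the reference op: returns a row-major
// [rank x n] index tensor. Unused slots hold the sentinel rank * n,
// which is the output element count.
std::vector<std::int64_t> nonzero_padded(const std::vector<int>& data,
                                         const std::vector<std::size_t>& dims)
{
    const std::size_t n    = data.size(); // assumed equal to the product of dims
    const std::size_t rank = dims.size();
    const auto sentinel    = static_cast<std::int64_t>(rank * n);
    std::vector<std::int64_t> out(rank * n, sentinel);

    std::size_t col = 0;
    for(std::size_t flat = 0; flat < n; ++flat)
    {
        if(data[flat] == 0)
            continue;
        // Decompose the flat offset into one index per dimension (row-major).
        std::size_t rem = flat;
        for(std::size_t d = rank; d-- > 0;)
        {
            out[d * n + col] = static_cast<std::int64_t>(rem % dims[d]);
            rem /= dims[d];
        }
        ++col;
    }
    return out;
}

// Counts the valid columns by scanning the first row until the sentinel appears.
std::size_t count_valid(const std::vector<std::int64_t>& out, std::size_t rank)
{
    if(rank == 0)
        return 0;
    const std::size_t n = out.size() / rank;
    const auto sentinel = static_cast<std::int64_t>(out.size());
    std::size_t count   = 0;
    while(count < n && out[count] < sentinel)
        ++count;
    return count;
}
```

Because every valid index is smaller than its dimension, and every dimension is at most n, the sentinel rank * n can never collide with a real index. With a {2, 2, 3} input holding eight nonzero values, the sentinel is 36 and `count_valid` reports 8.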

Originally I intended to use -1, but not every data type can represent it, for example
when the vectors hold unsigned values or booleans.
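
As a small illustration of that concern (a sketch, with `as_stored` being a hypothetical helper rather than anything in the codebase), converting -1 to an unsigned or boolean element type does not preserve it:

```cpp
#include <cstdint>

// Hypothetical helper: converts a candidate sentinel to the element type T,
// the way a generic fill over an arbitrary output type would.
template <class T>
constexpr T as_stored(std::int64_t sentinel)
{
    return static_cast<T>(sentinel);
}
```

`as_stored<std::uint8_t>(-1)` wraps to 255 and `as_stored<bool>(-1)` collapses to `true`, whereas a non-negative sentinel such as an element count converts unchanged as long as it fits in the type.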
parent 84a8f450
@@ -66,7 +66,7 @@ struct nonzero
         argument result{output_shape};
         result.visit([&](auto output) {
-            std::fill(output.begin(), output.end(), 0);
+            std::fill(output.begin(), output.end(), output_shape.elements());
             par_for(vec_idx.size(), [&](auto i) {
                 for(std::size_t j = 0; j < vec_idx.front().size(); ++j)
                 {
......
@@ -45,8 +45,8 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg
         const auto* in_ptr = device_cast(input.data());
         auto* ptr = result.cast<int64_t>();
         gs_launch(stream, block_size, block_size)([=](auto, auto idx) __device__ {
-            // fill all output to 0 first
-            idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; });
+            // fill all output to elem_num first
+            idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = elem_num; });
             block_scan<block_size>(
                 idx,
......
@@ -5283,8 +5283,8 @@ TEST_CASE(nonzero_test)
     auto result = p.eval({}).back();
     std::vector<int64_t> result_vector;
     result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
-    std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
-                                 1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0};
+    std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 36, 36, 36, 36, 0, 0, 1, 1, 0, 0,
+                                 1, 1, 36, 36, 36, 36, 0, 1, 0, 2, 0, 2, 0, 2, 36, 36, 36, 36};
     EXPECT(migraphx::verify_range(result_vector, gold));
 }
......