Commit 6382ff10 authored by charlie's avatar charlie
Browse files

Dynamic output shape works

parent 39f5fe15
......@@ -68,7 +68,7 @@ struct nonmaxsuppression
MIGRAPHX_THROW("NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input");
}
if(boxes_dims.at(0) != inputs.at(1).lens()[0])
if(boxes_dims.at(0) != scores_dims.at(0))
{
MIGRAPHX_THROW("NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input");
......@@ -237,7 +237,9 @@ struct nonmaxsuppression
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// make buffer of maximum size
shape max_output_shape = {output_shape.type(), output_shape.max_lens()};
argument result{max_output_shape};
std::size_t max_output_boxes_per_class =
(args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
......@@ -258,7 +260,7 @@ struct nonmaxsuppression
// boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements());
selected_boxes_inside_class.reserve(max_output_shape.elements());
// iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) {
......@@ -281,11 +283,11 @@ struct nonmaxsuppression
class_idx);
});
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
num_selected = selected_indices.size();
num_selected = selected_indices.size() / 3;
});
});
return result.reshape({num_selected, 3});
return result.reshape({output_shape.type(), {num_selected, 3}});
}
};
......
......@@ -3433,7 +3433,7 @@ TEST_CASE(nms_not_center_test)
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold));
}
......@@ -3466,7 +3466,7 @@ TEST_CASE(nms_test)
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold));
}
......@@ -3503,7 +3503,7 @@ TEST_CASE(nms_transpose1_test)
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold));
}
......@@ -3540,7 +3540,7 @@ TEST_CASE(nms_transpose2_test)
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment