Commit 6382ff10 authored by charlie's avatar charlie
Browse files

Dynamic output shape works

parent 39f5fe15
...@@ -68,7 +68,7 @@ struct nonmaxsuppression ...@@ -68,7 +68,7 @@ struct nonmaxsuppression
MIGRAPHX_THROW("NonMaxSuppression: dynamic spatial dimension mismatch between " MIGRAPHX_THROW("NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input"); "boxes and scores input");
} }
if(boxes_dims.at(0) != inputs.at(1).lens()[0]) if(boxes_dims.at(0) != scores_dims.at(0))
{ {
MIGRAPHX_THROW("NonMaxSuppression: dynamic number of batches mismatch between " MIGRAPHX_THROW("NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input"); "boxes and scores input");
...@@ -237,7 +237,9 @@ struct nonmaxsuppression ...@@ -237,7 +237,9 @@ struct nonmaxsuppression
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; // make buffer of maximum size
shape max_output_shape = {output_shape.type(), output_shape.max_lens()};
argument result{max_output_shape};
std::size_t max_output_boxes_per_class = std::size_t max_output_boxes_per_class =
(args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0; (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
...@@ -258,7 +260,7 @@ struct nonmaxsuppression ...@@ -258,7 +260,7 @@ struct nonmaxsuppression
// boxes of a class with NMS applied [score, index] // boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class; std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices; std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements()); selected_boxes_inside_class.reserve(max_output_shape.elements());
// iterate over batches and classes // iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}}; shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](auto idx) {
...@@ -281,11 +283,11 @@ struct nonmaxsuppression ...@@ -281,11 +283,11 @@ struct nonmaxsuppression
class_idx); class_idx);
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
num_selected = selected_indices.size(); num_selected = selected_indices.size() / 3;
}); });
}); });
return result.reshape({num_selected, 3}); return result.reshape({output_shape.type(), {num_selected, 3}});
} }
}; };
......
...@@ -3433,7 +3433,7 @@ TEST_CASE(nms_not_center_test) ...@@ -3433,7 +3433,7 @@ TEST_CASE(nms_not_center_test)
auto output = p.eval({}).back(); auto output = p.eval({}).back();
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
...@@ -3466,7 +3466,7 @@ TEST_CASE(nms_test) ...@@ -3466,7 +3466,7 @@ TEST_CASE(nms_test)
auto output = p.eval({}).back(); auto output = p.eval({}).back();
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
...@@ -3503,7 +3503,7 @@ TEST_CASE(nms_transpose1_test) ...@@ -3503,7 +3503,7 @@ TEST_CASE(nms_transpose1_test)
auto output = p.eval({}).back(); auto output = p.eval({}).back();
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
...@@ -3540,7 +3540,7 @@ TEST_CASE(nms_transpose2_test) ...@@ -3540,7 +3540,7 @@ TEST_CASE(nms_transpose2_test)
auto output = p.eval({}).back(); auto output = p.eval({}).back();
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment