Commit 803cbb7a authored by charlie's avatar charlie
Browse files

first draft

parent 6efffa37
...@@ -98,7 +98,7 @@ struct check_shapes ...@@ -98,7 +98,7 @@ struct check_shapes
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() != n) if(begin->max_lens().size() != n)
MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported"); MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
...@@ -110,7 +110,7 @@ struct check_shapes ...@@ -110,7 +110,7 @@ struct check_shapes
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() > n) if(begin->max_lens().size() > n)
MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
" dimensions"); " dimensions");
} }
......
...@@ -57,25 +57,42 @@ struct nonmaxsuppression ...@@ -57,25 +57,42 @@ struct nonmaxsuppression
{ {
// requires at least 2 inputs // requires at least 2 inputs
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3); check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
// both boxes and scores will be dynamic if one of them is dynamic
if(inputs.at(0).dynamic())
{
// check dynamic dimensions are consistent
const auto boxes_dims = inputs.at(0).dyn_dims();
const auto scores_dims = inputs.at(1).dyn_dims();
if(boxes_dims.at(1) != scores_dims.at(2))
{
MIGRAPHX_THROW("NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input");
}
if(boxes_dims.at(0) != inputs.at(1).lens()[0])
{
MIGRAPHX_THROW("NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input");
}
}
else
{
auto lens = inputs.front().lens(); auto lens = inputs.front().lens();
// check input shape
if(lens[1] != inputs.at(1).lens()[2]) if(lens[1] != inputs.at(1).lens()[2])
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"); "NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
} }
// check batch sizes
if(lens[0] != inputs.at(1).lens()[0]) if(lens[0] != inputs.at(1).lens()[0])
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
"NonMaxSuppression: number of batches mismatch between boxes and scores input"); "NonMaxSuppression: number of batches mismatch between boxes and scores input");
} }
}
std::vector<int64_t> out_lens(2); std::vector<shape::dynamic_dimension> out_lens = {};
out_lens.at(0) = lens.at(1); out_lens.push_back({0, inputs.at(0).max_lens().at(1), 0});
out_lens.at(1) = 3; out_lens.push_back({3, 3, 0});
return {shape::int64_type, out_lens}; return {shape::int64_type, out_lens};
} }
...@@ -230,10 +247,10 @@ struct nonmaxsuppression ...@@ -230,10 +247,10 @@ struct nonmaxsuppression
} }
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f; double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f; double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
std::size_t num_selected = 0;
result.visit([&](auto output) { result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) { visit_all(args[0], args[1])([&](auto boxes, auto scores) {
std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens(); const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0]; const auto num_batches = lens[0];
const auto num_classes = lens[1]; const auto num_classes = lens[1];
...@@ -264,10 +281,11 @@ struct nonmaxsuppression ...@@ -264,10 +281,11 @@ struct nonmaxsuppression
class_idx); class_idx);
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
num_selected = selected_indices.size();
}); });
}); });
return result; return result.reshape({num_selected, 3});
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment