Unverified Commit bae1d7e2 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Change to stable sort in nms implementations (#4767)

* change to stable sort in nms implementations
parent 7fa267e8
...@@ -27,7 +27,8 @@ at::Tensor nms_kernel_impl( ...@@ -27,7 +27,8 @@ at::Tensor nms_kernel_impl(
at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto ndets = dets.size(0); auto ndets = dets.size(0);
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
......
...@@ -109,7 +109,8 @@ at::Tensor nms_kernel( ...@@ -109,7 +109,8 @@ at::Tensor nms_kernel(
return at::empty({0}, dets.options().dtype(at::kLong)); return at::empty({0}, dets.options().dtype(at::kLong));
} }
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t).contiguous(); auto dets_sorted = dets.index_select(0, order_t).contiguous();
int dets_num = dets.size(0); int dets_num = dets.size(0);
......
...@@ -27,7 +27,8 @@ at::Tensor qnms_kernel_impl( ...@@ -27,7 +27,8 @@ at::Tensor qnms_kernel_impl(
auto y1_t = dets.select(1, 1).contiguous(); auto y1_t = dets.select(1, 1).contiguous();
auto x2_t = dets.select(1, 2).contiguous(); auto x2_t = dets.select(1, 2).contiguous();
auto y2_t = dets.select(1, 3).contiguous(); auto y2_t = dets.select(1, 3).contiguous();
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
at::Tensor areas_t = at::zeros({ndets}, dets.options().dtype(at::kFloat)); at::Tensor areas_t = at::zeros({ndets}, dets.options().dtype(at::kFloat));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment