Unverified Commit bae1d7e2 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Change to stable sort in nms implementations (#4767)

* change to stable sort in nms implementations
parent 7fa267e8
......@@ -27,7 +27,8 @@ at::Tensor nms_kernel_impl(
at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto ndets = dets.size(0);
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
......
......@@ -109,7 +109,8 @@ at::Tensor nms_kernel(
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t).contiguous();
int dets_num = dets.size(0);
......
......@@ -27,7 +27,8 @@ at::Tensor qnms_kernel_impl(
auto y1_t = dets.select(1, 1).contiguous();
auto x2_t = dets.select(1, 2).contiguous();
auto y2_t = dets.select(1, 3).contiguous();
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
at::Tensor areas_t = at::zeros({ndets}, dets.options().dtype(at::kFloat));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment