Unverified commit 455cd57c, authored by Vasilis Vryniotis and committed by GitHub
Browse files

NMS code cleanup (#2907)

* Clean up and refactor ROIAlign implementation:
- Remove primitive const declaration from method names.
- Remove unnecessary headers.
- Aligning method names between cpu and cuda.

* Adding back include for cpu.

* Restoring method names of private methods to avoid conflicts.

* Restore include headers.
parent c9d9e67e
...@@ -4,7 +4,7 @@ template <typename scalar_t> ...@@ -4,7 +4,7 @@ template <typename scalar_t>
at::Tensor nms_cpu_kernel( at::Tensor nms_cpu_kernel(
const at::Tensor& dets, const at::Tensor& dets,
const at::Tensor& scores, const at::Tensor& scores,
const double iou_threshold) { double iou_threshold) {
TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor"); TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor"); TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
TORCH_CHECK( TORCH_CHECK(
...@@ -72,7 +72,7 @@ at::Tensor nms_cpu_kernel( ...@@ -72,7 +72,7 @@ at::Tensor nms_cpu_kernel(
at::Tensor nms_cpu( at::Tensor nms_cpu(
const at::Tensor& dets, const at::Tensor& dets,
const at::Tensor& scores, const at::Tensor& scores,
const double iou_threshold) { double iou_threshold) {
TORCH_CHECK( TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK( TORCH_CHECK(
......
...@@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cpu( ...@@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cpu(
VISION_API at::Tensor nms_cpu( VISION_API at::Tensor nms_cpu(
const at::Tensor& dets, const at::Tensor& dets,
const at::Tensor& scores, const at::Tensor& scores,
const double iou_threshold); double iou_threshold);
VISION_API at::Tensor DeformConv2d_forward_cpu( VISION_API at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input, const at::Tensor& input,
......
...@@ -22,8 +22,8 @@ __device__ inline bool devIoU(T const* const a, T const* const b, const float th ...@@ -22,8 +22,8 @@ __device__ inline bool devIoU(T const* const a, T const* const b, const float th
template <typename T> template <typename T>
__global__ void nms_kernel( __global__ void nms_kernel(
const int n_boxes, int n_boxes,
const float iou_threshold, double iou_threshold,
const T* dev_boxes, const T* dev_boxes,
unsigned long long* dev_mask) { unsigned long long* dev_mask) {
const int row_start = blockIdx.y; const int row_start = blockIdx.y;
...@@ -70,7 +70,7 @@ __global__ void nms_kernel( ...@@ -70,7 +70,7 @@ __global__ void nms_kernel(
at::Tensor nms_cuda(const at::Tensor& dets, at::Tensor nms_cuda(const at::Tensor& dets,
const at::Tensor& scores, const at::Tensor& scores,
const double iou_threshold) { double iou_threshold) {
TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor");
TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor");
......
...@@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cuda( ...@@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cuda(
VISION_API at::Tensor nms_cuda( VISION_API at::Tensor nms_cuda(
const at::Tensor& dets, const at::Tensor& dets,
const at::Tensor& scores, const at::Tensor& scores,
const double iou_threshold); double iou_threshold);
VISION_API at::Tensor DeformConv2d_forward_cuda( VISION_API at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input, const at::Tensor& input,
......
#pragma once #pragma once
#include "cpu/vision_cpu.h" #include "cpu/vision_cpu.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
...@@ -14,7 +15,7 @@ ...@@ -14,7 +15,7 @@
at::Tensor nms( at::Tensor nms(
const at::Tensor& dets, const at::Tensor& dets,
const at::Tensor& scores, const at::Tensor& scores,
const double iou_threshold) { double iou_threshold) {
static auto op = c10::Dispatcher::singleton() static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::nms", "") .findSchemaOrThrow("torchvision::nms", "")
.typed<decltype(nms)>(); .typed<decltype(nms)>();
...@@ -25,7 +26,7 @@ at::Tensor nms( ...@@ -25,7 +26,7 @@ at::Tensor nms(
at::Tensor nms_autocast( at::Tensor nms_autocast(
const at::Tensor& dets, const at::Tensor& dets,
const at::Tensor& scores, const at::Tensor& scores,
const double iou_threshold) { double iou_threshold) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return nms( return nms(
at::autocast::cached_cast(at::kFloat, dets), at::autocast::cached_cast(at::kFloat, dets),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment