"vscode:/vscode.git/clone" did not exist on "7dd6878820df93ffbf17d37d942bd940dde5827a"
Unverified Commit 455cd57c authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

NMS code cleanup (#2907)

* Clean up and refactor ROIAlign implementation:
- Remove primitive const declaration from method names.
- Remove unnecessary headers.
- Aligning method names between cpu and cuda.

* Adding back include for cpu.

* Restoring method names of private methods to avoid conflicts.

* Restore include headers.
parent c9d9e67e
......@@ -4,7 +4,7 @@ template <typename scalar_t>
at::Tensor nms_cpu_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
TORCH_CHECK(
......@@ -72,7 +72,7 @@ at::Tensor nms_cpu_kernel(
at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
......
......@@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cpu(
VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold);
double iou_threshold);
VISION_API at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input,
......
......@@ -22,8 +22,8 @@ __device__ inline bool devIoU(T const* const a, T const* const b, const float th
template <typename T>
__global__ void nms_kernel(
const int n_boxes,
const float iou_threshold,
int n_boxes,
double iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask) {
const int row_start = blockIdx.y;
......@@ -70,7 +70,7 @@ __global__ void nms_kernel(
at::Tensor nms_cuda(const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor");
TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor");
......
......@@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cuda(
VISION_API at::Tensor nms_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold);
double iou_threshold);
VISION_API at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input,
......
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
......@@ -14,7 +15,7 @@
at::Tensor nms(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::nms", "")
.typed<decltype(nms)>();
......@@ -25,7 +26,7 @@ at::Tensor nms(
at::Tensor nms_autocast(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return nms(
at::autocast::cached_cast(at::kFloat, dets),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment