Unverified Commit ac3ba944 authored by Ashish Farmer's avatar Ashish Farmer Committed by GitHub
Browse files

Enable autocast for NMS and ROIAlign on ROCm (#2637)

* add autocasting on ROCm

* enable ROIAlign autocasting on ROCm

* enable NMS autocasting on ROCm

* fix to use correct torch CUDA APIs
parent 02f46a57
......@@ -7,6 +7,7 @@
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif
......@@ -37,7 +38,7 @@ at::Tensor roi_align(
aligned);
}
#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor ROIAlign_autocast(
const at::Tensor& input,
const at::Tensor& rois,
......
#pragma once
#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_HIP)
namespace autocast {
inline bool is_eligible(const at::Tensor& arg) {
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#if defined(WITH_CUDA)
#include <c10/cuda/CUDAGuard.h>
#elif defined(WITH_HIP)
#include <c10/hip/HIPGuard.h>
#endif
#include "cuda_helpers.h"
......@@ -98,10 +93,8 @@ at::Tensor nms_cuda(const at::Tensor& dets,
" and ",
scores.size(0))
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::cuda::CUDAGuard device_guard(dets.device());
#elif defined(WITH_HIP)
at::cuda::HIPGuard device_guard(dets.device());
#else
AT_ERROR("Not compiled with GPU support");
#endif
......
......@@ -6,6 +6,7 @@
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif
......@@ -20,7 +21,7 @@ at::Tensor nms(
return op.call(dets, scores, iou_threshold);
}
#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor nms_autocast(
const at::Tensor& dets,
const at::Tensor& scores,
......
......@@ -72,7 +72,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
#endif
// Autocast only needs to wrap forward pass ops.
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", ROIAlign_autocast);
m.impl("nms", nms_autocast);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment