Unverified commit ac3ba944, authored by Ashish Farmer and committed by GitHub
Browse files

Enable autocast for NMS and ROIAlign on ROCm (#2637)

* add autocasting on ROCm

* enable ROIAlign autocasting on ROCm

* enable NMS autocasting on ROCm

* fix to use correct torch CUDA APIs
parent 02f46a57
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "cuda/vision_cuda.h" #include "cuda/vision_cuda.h"
#endif #endif
#ifdef WITH_HIP #ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h" #include "hip/vision_cuda.h"
#endif #endif
...@@ -37,7 +38,7 @@ at::Tensor roi_align( ...@@ -37,7 +38,7 @@ at::Tensor roi_align(
aligned); aligned);
} }
#ifdef WITH_CUDA #if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor ROIAlign_autocast( at::Tensor ROIAlign_autocast(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& rois, const at::Tensor& rois,
......
#pragma once #pragma once
#ifdef WITH_CUDA #if defined(WITH_CUDA) || defined(WITH_HIP)
namespace autocast { namespace autocast {
inline bool is_eligible(const at::Tensor& arg) { inline bool is_eligible(const at::Tensor& arg) {
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#if defined(WITH_CUDA)
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#elif defined(WITH_HIP)
#include <c10/hip/HIPGuard.h>
#endif
#include "cuda_helpers.h" #include "cuda_helpers.h"
...@@ -98,10 +93,8 @@ at::Tensor nms_cuda(const at::Tensor& dets, ...@@ -98,10 +93,8 @@ at::Tensor nms_cuda(const at::Tensor& dets,
" and ", " and ",
scores.size(0)) scores.size(0))
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_HIP)
at::cuda::CUDAGuard device_guard(dets.device()); at::cuda::CUDAGuard device_guard(dets.device());
#elif defined(WITH_HIP)
at::cuda::HIPGuard device_guard(dets.device());
#else #else
AT_ERROR("Not compiled with GPU support"); AT_ERROR("Not compiled with GPU support");
#endif #endif
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "cuda/vision_cuda.h" #include "cuda/vision_cuda.h"
#endif #endif
#ifdef WITH_HIP #ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h" #include "hip/vision_cuda.h"
#endif #endif
...@@ -20,7 +21,7 @@ at::Tensor nms( ...@@ -20,7 +21,7 @@ at::Tensor nms(
return op.call(dets, scores, iou_threshold); return op.call(dets, scores, iou_threshold);
} }
#ifdef WITH_CUDA #if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor nms_autocast( at::Tensor nms_autocast(
const at::Tensor& dets, const at::Tensor& dets,
const at::Tensor& scores, const at::Tensor& scores,
......
...@@ -72,7 +72,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { ...@@ -72,7 +72,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
#endif #endif
// Autocast only needs to wrap forward pass ops. // Autocast only needs to wrap forward pass ops.
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", ROIAlign_autocast); m.impl("roi_align", ROIAlign_autocast);
m.impl("nms", nms_autocast); m.impl("nms", nms_autocast);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment