Unverified Commit c8a9ae75 authored by Zachary Streeter's avatar Zachary Streeter Committed by GitHub
Browse files

[Fix] Using PyTorch WARP_SHFL_DOWN macro for half support (#2843)

parent 6e9ee267
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#ifndef CARAFE_CUDA_KERNEL_CUH #ifndef CARAFE_CUDA_KERNEL_CUH
#define CARAFE_CUDA_KERNEL_CUH #define CARAFE_CUDA_KERNEL_CUH
#include <ATen/cuda/DeviceUtils.cuh>
#ifdef MMCV_USE_PARROTS #ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp" #include "parrots_cuda_helper.hpp"
#else #else
...@@ -56,7 +58,8 @@ template <> ...@@ -56,7 +58,8 @@ template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) { __device__ __forceinline__ phalf warpReduceSum(phalf val) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP #ifdef MMCV_WITH_HIP
__PHALF(val) += __shfl_down(val, offset); // Using PyTorch's macro for half support
__PHALF(val) += WARP_SHFL_DOWN(val, offset);
#else #else
__PHALF(val) += __PHALF(val) +=
__shfl_down_sync(FULL_MASK, __PHALF(val).operator __half(), offset); __shfl_down_sync(FULL_MASK, __PHALF(val).operator __half(), offset);
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.