Unverified commit ea44da50, authored by ndickson-nvidia, committed by GitHub

[Bug] Added common operations for FP16 on older GPUs (#4079)

* Added support for common operations on FP16 (`half` / `__half`) for older GPU architectures.

* Fixed an issue with the previous check for FP16 support.

* Removed FP16 type checks, since they should no longer be needed.

* Fixed `AtomicAdd` to be atomic for `float` and `double` on old GPU architectures. Unfortunately, `atomicCAS` for `unsigned short` appears to be unavailable until compute architecture 70, so `half` has to stay non-atomic on old GPUs.

* Fixed the non-atomic version of `AtomicAdd<half>` for older GPUs to return the old value instead of the new one.
parent 31a81438
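
For context, the patch replaces the old non-atomic fallback (`return *addr + val;`, which never wrote back and returned the wrong value) with the classic compare-and-swap loop. Below is a minimal standalone sketch of that pattern for `float`, using only CUDA intrinsics; the patch itself routes the bit reinterpretation through DGL's `Cast<T>` helper rather than `__float_as_uint`/`__uint_as_float`, but the logic is the same.

```cpp
// Minimal sketch of CAS-based atomicAdd emulation for float.
__device__ __forceinline__ float EmulatedAtomicAddF32(float* addr, float val) {
  // Reinterpret the float's bits as an unsigned int so atomicCAS applies.
  unsigned int* addr_as_ui = reinterpret_cast<unsigned int*>(addr);
  unsigned int old = *addr_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    // atomicCAS returns the value it actually found at addr; if another
    // thread updated the location first, the loop retries with that value.
    old = atomicCAS(addr_as_ui, assumed,
                    __float_as_uint(__uint_as_float(assumed) + val));
  } while (assumed != old);
  // Return the previous value, matching the semantics of atomicAdd.
  return __uint_as_float(old);
}
```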
@@ -219,7 +219,17 @@ __device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
 #if __CUDA_ARCH__ >= 200
   return atomicAdd(addr, val);
 #else
-  return *addr + val;
+  typedef float T;
+  typedef typename Cast<T>::Type CT;
+  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
+  CT old = *addr_as_ui;
+  CT assumed = old;
+  do {
+    assumed = old;
+    old = atomicCAS(addr_as_ui, assumed,
+        Cast<T>::Encode(Cast<T>::Decode(old) + val));
+  } while (assumed != old);
+  return Cast<T>::Decode(old);
 #endif // __CUDA_ARCH__
 }
@@ -228,23 +238,33 @@ __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
 #if __CUDA_ARCH__ >= 600
   return atomicAdd(addr, val);
 #else
-  return *addr + val;
+  typedef double T;
+  typedef typename Cast<T>::Type CT;
+  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
+  CT old = *addr_as_ui;
+  CT assumed = old;
+  do {
+    assumed = old;
+    old = atomicCAS(addr_as_ui, assumed,
+        Cast<T>::Encode(Cast<T>::Decode(old) + val));
+  } while (assumed != old);
+  return Cast<T>::Decode(old);
 #endif
 }
 
 #ifdef USE_FP16
 #if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
-// half
-#if __CUDA_ARCH__ >= 600
 template <>
 __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
+// make sure we have half support
 #if __CUDA_ARCH__ >= 700
   return atomicAdd(addr, val);
 #else
-  return *addr + val;
+  half old = *addr;
+  *addr = half(float(old) + float(val));
+  return old;
 #endif // __CUDA_ARCH__ >= 700
 }
-#endif // __CUDA_ARCH__ >= 600
 #endif // defined(CUDART_VERSION) && CUDART_VERSION >= 10000
 #endif // USE_FP16
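
As noted in the commit message, a 16-bit `atomicCAS` overload only appears at compute capability 7.0, which is also where native `atomicAdd(half*, half)` becomes available, so there is no architecture on which a CAS loop for `half` would gain anything. Purely for illustration, a hypothetical CAS-based version, not part of this patch and only compilable for `__CUDA_ARCH__ >= 700`, would look like this:

```cpp
#include <cuda_fp16.h>

// Hypothetical sketch only: atomicCAS(unsigned short int*, ...) requires
// compute capability >= 7.0, where native atomicAdd(half*, half) already
// exists, which is why the patch keeps the non-atomic fallback instead.
__device__ __forceinline__ half CasAtomicAddHalf(half* addr, half val) {
  unsigned short int* addr_as_us =
      reinterpret_cast<unsigned short int*>(addr);
  unsigned short int old = *addr_as_us;
  unsigned short int assumed;
  do {
    assumed = old;
    old = atomicCAS(addr_as_us, assumed,
                    __half_as_ushort(__ushort_as_half(assumed) + val));
  } while (assumed != old);
  return __ushort_as_half(old);
}
```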
@@ -17,7 +17,7 @@ static __device__ __forceinline__ half max(half a, half b)
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hgt(__half(a), __half(b)) ? a : b;
 #else
-  return a;
+  return __half(max(float(a), float(b)));
 #endif
 }
@@ -26,9 +26,40 @@ static __device__ __forceinline__ half min(half a, half b)
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hlt(__half(a), __half(b)) ? a : b;
 #else
-  return a;
+  return __half(min(float(a), float(b)));
 #endif
 }
+
+#ifdef __CUDACC__
+// Arithmetic FP16 operations for architecture >= 5.3 are already defined in cuda_fp16.h
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)
+__device__ __forceinline__ __half operator+(const __half& lh, const __half& rh) { return __half(float(lh) + float(rh)); }
+__device__ __forceinline__ __half operator-(const __half& lh, const __half& rh) { return __half(float(lh) - float(rh)); }
+__device__ __forceinline__ __half operator*(const __half& lh, const __half& rh) { return __half(float(lh) * float(rh)); }
+__device__ __forceinline__ __half operator/(const __half& lh, const __half& rh) { return __half(float(lh) / float(rh)); }
+__device__ __forceinline__ __half& operator+=(__half& lh, const __half& rh) { lh = __half(float(lh) + float(rh)); return lh; }
+__device__ __forceinline__ __half& operator-=(__half& lh, const __half& rh) { lh = __half(float(lh) - float(rh)); return lh; }
+__device__ __forceinline__ __half& operator*=(__half& lh, const __half& rh) { lh = __half(float(lh) * float(rh)); return lh; }
+__device__ __forceinline__ __half& operator/=(__half& lh, const __half& rh) { lh = __half(float(lh) / float(rh)); return lh; }
+__device__ __forceinline__ __half& operator++(__half& h) { h = __half(float(h) + 1.0f); return h; }
+__device__ __forceinline__ __half& operator--(__half& h) { h = __half(float(h) - 1.0f); return h; }
+__device__ __forceinline__ __half operator++(__half& h, int) { __half ret = h; h = __half(float(h) + 1.0f); return ret; }
+__device__ __forceinline__ __half operator--(__half& h, int) { __half ret = h; h = __half(float(h) - 1.0f); return ret; }
+__device__ __forceinline__ __half operator+(const __half& h) { return h; }
+__device__ __forceinline__ __half operator-(const __half& h) { return __half(-float(h)); }
+__device__ __forceinline__ bool operator==(const __half& lh, const __half& rh) { return float(lh) == float(rh); }
+__device__ __forceinline__ bool operator!=(const __half& lh, const __half& rh) { return float(lh) != float(rh); }
+__device__ __forceinline__ bool operator> (const __half& lh, const __half& rh) { return float(lh) > float(rh); }
+__device__ __forceinline__ bool operator< (const __half& lh, const __half& rh) { return float(lh) < float(rh); }
+__device__ __forceinline__ bool operator>=(const __half& lh, const __half& rh) { return float(lh) >= float(rh); }
+__device__ __forceinline__ bool operator<=(const __half& lh, const __half& rh) { return float(lh) <= float(rh); }
+#endif // __CUDA_ARCH__ < 530
+#endif // __CUDACC__
+
 #endif // USE_FP16
 #endif // DGL_ARRAY_FP16_CUH_
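
The point of the fallback operators above is that generic kernel code written against a template `DType` can now instantiate for `half` below architecture 5.3, where `cuda_fp16.h` defines no `__half` arithmetic. A hedged usage sketch, with a made-up `AxpyKernel` that is not part of DGL:

```cpp
#include <cuda_fp16.h>

// Hypothetical example: with the fallback operators in scope, this generic
// kernel also compiles for DType = half on sm_35/sm_50, where the arithmetic
// round-trips through float; on sm_53+ the cuda_fp16.h operators are used.
template <typename DType>
__global__ void AxpyKernel(const DType* x, DType* y, DType alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] += alpha * x[i];  // uses operator* and operator+=
  }
}
// e.g. AxpyKernel<half><<<(n + 255) / 256, 256>>>(x, y, half(2.0f), n);
```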
@@ -21,7 +21,6 @@ namespace cuda {
 #define CUDA_MAX_NUM_THREADS 1024
 
 #ifdef USE_FP16
-#if __CUDA_ARCH__ >= 600
 #define SWITCH_BITS(bits, DType, ...)                        \
   do {                                                       \
     if ((bits) == 16) {                                      \
@@ -37,22 +36,6 @@ namespace cuda {
       LOG(FATAL) << "Data type not recognized with bits " << bits; \
     }                                                        \
   } while (0)
-#else
-#define SWITCH_BITS(bits, DType, ...)                        \
-  do {                                                       \
-    if ((bits) == 16) {                                      \
-      LOG(FATAL) << "FP16 only supported on CUDA architectures >= 60"; \
-    } else if ((bits) == 32) {                               \
-      typedef float DType;                                   \
-      { __VA_ARGS__ }                                        \
-    } else if ((bits) == 64) {                               \
-      typedef double DType;                                  \
-      { __VA_ARGS__ }                                        \
-    } else {                                                 \
-      LOG(FATAL) << "Data type not recognized with bits " << bits; \
-    }                                                        \
-  } while (0)
-#endif // __CUDA_ARCH__ >= 600
 #else // USE_FP16
 #define SWITCH_BITS(bits, DType, ...)                        \
   do {                                                       \
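
For reference, a usage sketch of `SWITCH_BITS` (the `SomeKernel`/`Launch` names are hypothetical, not DGL APIs): the macro binds the C++ type matching the runtime bit width to `DType` for the enclosed block, and with the architecture split removed, the 16-bit branch is compiled in unconditionally whenever DGL is built with `USE_FP16`.

```cpp
// Hypothetical usage sketch of SWITCH_BITS; SomeKernel and Launch are
// illustrative names, not part of DGL.
template <typename DType>
__global__ void SomeKernel(DType* data, int64_t n) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += DType(1.0f);
}

void Launch(int bits, void* data, int64_t n) {
  SWITCH_BITS(bits, DType, {
    int nt = 256;
    int nb = static_cast<int>((n + nt - 1) / nt);
    SomeKernel<DType><<<nb, nt>>>(static_cast<DType*>(data), n);
  });
}
```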