"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7635d3d37fa458d57d90deec364c8829b4ba592b"
Unverified commit 7a065a9c, authored by nv-dlasalle and committed by GitHub
Browse files

[Build][Tests] Enable FP16 for GPU builds in CI (#4030)

* Enable FP16 for GPU builds in CI

* Limit default GPU archs to pascal and above

* Disable FP16 dispatching for cuda architectures less than 60

* Fix linting

* Fix typos
parent d1124b7b
...@@ -234,14 +234,17 @@ __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) { ...@@ -234,14 +234,17 @@ __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#ifdef USE_FP16 #ifdef USE_FP16
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000 #if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
// half make sure we have half support
#if __CUDA_ARCH__ >= 600
template <> template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) { __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
return atomicAdd(addr, val); return atomicAdd(addr, val);
#else #else
return *addr + val; return *addr + val;
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__ >= 700
} }
#endif // __CUDA_ARCH__ >= 600
#endif // defined(CUDART_VERSION) && CUDART_VERSION >= 10000 #endif // defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#endif // USE_FP16 #endif // USE_FP16
......
...@@ -21,6 +21,7 @@ namespace cuda { ...@@ -21,6 +21,7 @@ namespace cuda {
#define CUDA_MAX_NUM_THREADS 1024 #define CUDA_MAX_NUM_THREADS 1024
#ifdef USE_FP16 #ifdef USE_FP16
#if __CUDA_ARCH__ >= 600
#define SWITCH_BITS(bits, DType, ...) \ #define SWITCH_BITS(bits, DType, ...) \
do { \ do { \
if ((bits) == 16) { \ if ((bits) == 16) { \
...@@ -36,6 +37,22 @@ namespace cuda { ...@@ -36,6 +37,22 @@ namespace cuda {
LOG(FATAL) << "Data type not recognized with bits " << bits; \ LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \ } \
} while (0) } while (0)
#else
// Fallback dispatcher, selected when the enclosing `#if __CUDA_ARCH__ >= 600`
// test fails. It binds DType to float for 32 bits or to double for 64 bits and
// expands the trailing statements with that typedef in scope. A 16-bit request,
// or any other width, is reported through LOG(FATAL).
//
// NOTE(review): __CUDA_ARCH__ is only defined while nvcc compiles device code,
// and an undefined identifier evaluates to 0 inside #if, so host-side users of
// this macro would appear to get this fallback too -- confirm that is intended.
#define SWITCH_BITS(bits, DType, ...)                                   \
  do {                                                                  \
    if ((bits) == 16) {                                                 \
      LOG(FATAL) << "FP16 only supported on CUDA architectures >= 60";  \
    } else if ((bits) == 32) {                                          \
      typedef float DType;                                              \
      { __VA_ARGS__ }                                                   \
    } else if ((bits) == 64) {                                          \
      typedef double DType;                                             \
      { __VA_ARGS__ }                                                   \
    } else {                                                            \
      /* Parenthesised so an expression argument is streamed whole. */  \
      LOG(FATAL) << "Data type not recognized with bits " << (bits);    \
    }                                                                   \
  } while (0)
#endif // __CUDA_ARCH__ >= 600
#else // USE_FP16 #else // USE_FP16
#define SWITCH_BITS(bits, DType, ...) \ #define SWITCH_BITS(bits, DType, ...) \
do { \ do { \
......
...@@ -20,7 +20,7 @@ if [[ $arch == *"x86"* ]]; then ...@@ -20,7 +20,7 @@ if [[ $arch == *"x86"* ]]; then
fi fi
if [ "$1" == "gpu" ]; then if [ "$1" == "gpu" ]; then
CMAKE_VARS="-DUSE_CUDA=ON -DUSE_NCCL=ON $CMAKE_VARS" CMAKE_VARS="-DUSE_CUDA=ON -DUSE_NCCL=ON -DUSE_FP16=ON $CMAKE_VARS"
fi fi
if [ -d build ]; then if [ -d build ]; then
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment