/**
 *  Copyright (c) 2022 by Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * @file array/cuda/bf16.cuh
 * @brief bfloat16 related functions.
 *
 * Everything in this header is compiled only when BF16_ENABLED evaluates to a
 * non-zero value. It provides device-side max/min for __nv_bfloat16 and, for
 * device architectures below 8.0 on toolkits that do not already supply them,
 * float-emulated arithmetic and comparison operators.
 */
#ifndef DGL_ARRAY_CUDA_BF16_CUH_
#define DGL_ARRAY_CUDA_BF16_CUH_

#if BF16_ENABLED
#include <cuda_bf16.h>

#include <algorithm>

/**
 * @brief Maximum of two bfloat16 values (device only).
 *
 * When compiling for __CUDA_ARCH__ >= 800 this forwards to __hmax. Otherwise
 * both operands are converted to float, compared there, and the result is
 * converted back to __nv_bfloat16.
 *
 * @param a First operand.
 * @param b Second operand.
 * @return The larger of a and b.
 */
static __device__ __forceinline__ __nv_bfloat16
max(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __hmax(a, b);
#else
  return __nv_bfloat16(max(float(a), float(b)));  // NOLINT
#endif
}

/**
 * @brief Minimum of two bfloat16 values (device only).
 *
 * When compiling for __CUDA_ARCH__ >= 800 this forwards to __hmin. Otherwise
 * both operands are converted to float, compared there, and the result is
 * converted back to __nv_bfloat16.
 *
 * @param a First operand.
 * @param b Second operand.
 * @return The smaller of a and b.
 */
static __device__ __forceinline__ __nv_bfloat16
min(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __hmin(a, b);
#else
  return __nv_bfloat16(min(float(a), float(b)));  // NOLINT
#endif
}

#ifdef __CUDACC__
// Arithmetic BF16 operations for architecture >= 8.0 are already defined in
// cuda_bf16.h.
//
// Starting with CUDA 12.2 (CUDART_VERSION 12020), cuda_bf16.h also provides
// these operators for older architectures, so the fallbacks below are only
// emitted for earlier toolkits; defining them again would clash with the
// toolkit's own definitions. If CUDART_VERSION is not defined the toolkit
// version is unknown and the fallbacks are kept.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && \
    (!defined(CUDART_VERSION) || (CUDART_VERSION < 12020))

// ---------------------------------------------------------------------------
// Binary arithmetic: convert both operands to float, operate, convert back.
// ---------------------------------------------------------------------------

/** @brief lh + rh, evaluated in float. */
__device__ __forceinline__ __nv_bfloat16
operator+(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) + float(rh));  // NOLINT
}

/** @brief lh - rh, evaluated in float. */
__device__ __forceinline__ __nv_bfloat16
operator-(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) - float(rh));  // NOLINT
}

/** @brief lh * rh, evaluated in float. */
__device__ __forceinline__ __nv_bfloat16
operator*(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) * float(rh));  // NOLINT
}

/** @brief lh / rh, evaluated in float. */
__device__ __forceinline__ __nv_bfloat16
operator/(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) / float(rh));  // NOLINT
}

// ---------------------------------------------------------------------------
// Compound assignment: update lh in place and return a reference to it.
// ---------------------------------------------------------------------------

/** @brief lh = lh + rh, evaluated in float; returns lh. */
__device__ __forceinline__ __nv_bfloat16& operator+=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) + float(rh));       // NOLINT
  return lh;
}

/** @brief lh = lh - rh, evaluated in float; returns lh. */
__device__ __forceinline__ __nv_bfloat16& operator-=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) - float(rh));       // NOLINT
  return lh;
}

/** @brief lh = lh * rh, evaluated in float; returns lh. */
__device__ __forceinline__ __nv_bfloat16& operator*=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) * float(rh));       // NOLINT
  return lh;
}

/** @brief lh = lh / rh, evaluated in float; returns lh. */
__device__ __forceinline__ __nv_bfloat16& operator/=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) / float(rh));       // NOLINT
  return lh;
}

// ---------------------------------------------------------------------------
// Increment / decrement by 1.0f.
// ---------------------------------------------------------------------------

/** @brief Pre-increment: adds 1.0f to h and returns a reference to h. */
__device__ __forceinline__ __nv_bfloat16& operator++(
    __nv_bfloat16& h) {                 // NOLINT
  h = __nv_bfloat16(float(h) + 1.0f);  // NOLINT
  return h;
}

/** @brief Pre-decrement: subtracts 1.0f from h and returns a reference to h. */
__device__ __forceinline__ __nv_bfloat16& operator--(
    __nv_bfloat16& h) {                 // NOLINT
  h = __nv_bfloat16(float(h) - 1.0f);  // NOLINT
  return h;
}

/** @brief Post-increment: adds 1.0f to h and returns the value h held before. */
__device__ __forceinline__ __nv_bfloat16
operator++(__nv_bfloat16& h, int) {  // NOLINT
  __nv_bfloat16 ret = h;
  h = __nv_bfloat16(float(h) + 1.0f);  // NOLINT
  return ret;
}

/**
 * @brief Post-decrement: subtracts 1.0f from h and returns the value h held
 *        before.
 */
__device__ __forceinline__ __nv_bfloat16
operator--(__nv_bfloat16& h, int) {  // NOLINT
  __nv_bfloat16 ret = h;
  h = __nv_bfloat16(float(h) - 1.0f);  // NOLINT
  return ret;
}

// ---------------------------------------------------------------------------
// Unary plus / minus.
// ---------------------------------------------------------------------------

/** @brief Unary plus: returns a copy of h unchanged. */
__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& h) {
  return h;
}

/** @brief Unary minus: negates h in float and converts back. */
__device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16& h) {
  return __nv_bfloat16(-float(h));  // NOLINT
}

// ---------------------------------------------------------------------------
// Comparisons: both operands are converted to float and compared there.
// ---------------------------------------------------------------------------

/** @brief True when float(lh) == float(rh). */
__device__ __forceinline__ bool operator==(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) == float(rh);  // NOLINT
}

/** @brief True when float(lh) != float(rh). */
__device__ __forceinline__ bool operator!=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) != float(rh);  // NOLINT
}

/** @brief True when float(lh) > float(rh). */
__device__ __forceinline__ bool operator>(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) > float(rh);  // NOLINT
}

/** @brief True when float(lh) < float(rh). */
__device__ __forceinline__ bool operator<(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) < float(rh);  // NOLINT
}

/** @brief True when float(lh) >= float(rh). */
__device__ __forceinline__ bool operator>=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) >= float(rh);  // NOLINT
}

/** @brief True when float(lh) <= float(rh). */
__device__ __forceinline__ bool operator<=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) <= float(rh);  // NOLINT
}

#endif  // __CUDA_ARCH__ < 800 && (no CUDART_VERSION || CUDART_VERSION < 12020)
#endif  // __CUDACC__

#endif  // BF16_ENABLED

#endif  // DGL_ARRAY_CUDA_BF16_CUH_