/**
 *  Copyright (c) 2022 by Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file array/cuda/bf16.cuh
 * @brief bfloat16 related functions.
 */
#ifndef DGL_ARRAY_CUDA_BF16_CUH_
#define DGL_ARRAY_CUDA_BF16_CUH_

#if BF16_ENABLED
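// BF16_ENABLED is expected to be defined by the build system, typically when
// compiling against a CUDA toolkit new enough to ship <cuda_bf16.h> (11.0+).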
#include <cuda_bf16.h>

#include <algorithm>

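// max() overload for __nv_bfloat16: uses the native __hmax instruction on
// compute capability 8.0+ and otherwise falls back to comparing in fp32,
// since the bf16 max/min intrinsics require SM80 or newer.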
static __device__ __forceinline__ __nv_bfloat16
max(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __hmax(a, b);
#else
  return __nv_bfloat16(max(float(a), float(b)));  // NOLINT
#endif
}

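// min() overload for __nv_bfloat16, mirroring max() above.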
static __device__ __forceinline__ __nv_bfloat16
min(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __hmin(a, b);
#else
  return __nv_bfloat16(min(float(a), float(b)));  // NOLINT
#endif
}
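
// Usage sketch (illustrative only; `relu_kernel` is a hypothetical example,
// not part of this header): the overloads above let generic kernels call
// max()/min() on __nv_bfloat16 just as they would on float, e.g.:
//
//   template <typename T>
//   __global__ void relu_kernel(T* x, int64_t n) {
//     int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
//     if (i < n) x[i] = max(x[i], T(0.0f));
//   }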

#ifdef __CUDACC__
// Arithmetic BF16 operators for compute capability >= 8.0 are already
// defined in cuda_bf16.h.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
// CUDA 12.2 adds "emulated" support for these operators on older
// architectures, so the fallbacks below are only needed with older toolkits.
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
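// Binary arithmetic and compound-assignment operators, emulated by widening
// both operands to fp32, computing in fp32, and converting back to bf16.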
__device__ __forceinline__ __nv_bfloat16
operator+(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) + float(rh));  // NOLINT
}
__device__ __forceinline__ __nv_bfloat16
operator-(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) - float(rh));  // NOLINT
}
__device__ __forceinline__ __nv_bfloat16
operator*(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) * float(rh));  // NOLINT
}
__device__ __forceinline__ __nv_bfloat16
operator/(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) / float(rh));  // NOLINT
}

__device__ __forceinline__ __nv_bfloat16& operator+=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) + float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __nv_bfloat16& operator-=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) - float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __nv_bfloat16& operator*=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) * float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __nv_bfloat16& operator/=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) / float(rh));       // NOLINT
  return lh;
}

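// Pre- and post-increment/decrement, likewise emulated by a round trip
// through fp32.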
__device__ __forceinline__ __nv_bfloat16& operator++(
    __nv_bfloat16& h) {                // NOLINT
  h = __nv_bfloat16(float(h) + 1.0f);  // NOLINT
  return h;
}
__device__ __forceinline__ __nv_bfloat16& operator--(
    __nv_bfloat16& h) {                // NOLINT
  h = __nv_bfloat16(float(h) - 1.0f);  // NOLINT
  return h;
}
__device__ __forceinline__ __nv_bfloat16
operator++(__nv_bfloat16& h, int) {  // NOLINT
  __nv_bfloat16 ret = h;
  h = __nv_bfloat16(float(h) + 1.0f);  // NOLINT
  return ret;
}
__device__ __forceinline__ __nv_bfloat16
operator--(__nv_bfloat16& h, int) {  // NOLINT
  __nv_bfloat16 ret = h;
  h = __nv_bfloat16(float(h) - 1.0f);  // NOLINT
  return ret;
}

__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& h) {
  return h;
}
__device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16& h) {
  return __nv_bfloat16(-float(h));  // NOLINT
}

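// Comparison operators, performed in fp32 after widening. IEEE NaN semantics
// carry over: == is false and != is true when either operand is NaN.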
__device__ __forceinline__ bool operator==(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) == float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator!=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) != float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator>(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) > float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator<(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) < float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator>=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) >= float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator<=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) <= float(rh);  // NOLINT
}
#endif  // defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
#endif  // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#endif  // __CUDACC__

#endif  // BF16_ENABLED

#endif  // DGL_ARRAY_CUDA_BF16_CUH_