/**
 *  Copyright (c) 2022 by Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file array/cuda/bf16.cuh
 * @brief bfloat16 related functions.
 */
#ifndef DGL_ARRAY_CUDA_BF16_CUH_
#define DGL_ARRAY_CUDA_BF16_CUH_

#if BF16_ENABLED
#include <cuda_bf16.h>

#include <algorithm>

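// A brief note on the fallback path (added commentary, not in cuda_bf16.h):
// on devices without native bf16 `__hmax`/`__hmin` (compute capability
// < 8.0), operands are widened to float, which is exact for bf16, and
// compared there.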
static __device__ __forceinline__ __nv_bfloat16
max(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __hmax(a, b);
#else
  return __nv_bfloat16(max(float(a), float(b)));  // NOLINT
#endif
}

static __device__ __forceinline__ __nv_bfloat16
min(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __hmin(a, b);
#else
  return __nv_bfloat16(min(float(a), float(b)));  // NOLINT
#endif
}

#ifdef __CUDACC__
// Arithmetic BF16 operations for architecture >= 8.0 are already defined in
// cuda_bf16.h
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
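// Each operator below emulates bf16 arithmetic by computing in float:
// widening bf16 to float is exact, and the result is rounded back to bf16
// once via the __nv_bfloat16(float) conversion.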
__device__ __forceinline__ __nv_bfloat16
operator+(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) + float(rh));  // NOLINT
}
__device__ __forceinline__ __nv_bfloat16
operator-(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) - float(rh));  // NOLINT
}
__device__ __forceinline__ __nv_bfloat16
operator*(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) * float(rh));  // NOLINT
}
__device__ __forceinline__ __nv_bfloat16
operator/(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __nv_bfloat16(float(lh) / float(rh));  // NOLINT
}

__device__ __forceinline__ __nv_bfloat16& operator+=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) + float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __nv_bfloat16& operator-=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) - float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __nv_bfloat16& operator*=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) * float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __nv_bfloat16& operator/=(
    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT
  lh = __nv_bfloat16(float(lh) / float(rh));       // NOLINT
  return lh;
}

__device__ __forceinline__ __nv_bfloat16& operator++(
    __nv_bfloat16& h) {                // NOLINT
  h = __nv_bfloat16(float(h) + 1.0f);  // NOLINT
  return h;
}
__device__ __forceinline__ __nv_bfloat16& operator--(
    __nv_bfloat16& h) {                // NOLINT
  h = __nv_bfloat16(float(h) - 1.0f);  // NOLINT
  return h;
}
__device__ __forceinline__ __nv_bfloat16
operator++(__nv_bfloat16& h, int) {  // NOLINT
  __nv_bfloat16 ret = h;
  h = __nv_bfloat16(float(h) + 1.0f);  // NOLINT
  return ret;
}
__device__ __forceinline__ __nv_bfloat16
operator--(__nv_bfloat16& h, int) {  // NOLINT
  __nv_bfloat16 ret = h;
  h = __nv_bfloat16(float(h) - 1.0f);  // NOLINT
  return ret;
}

__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& h) {
  return h;
}
__device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16& h) {
  return __nv_bfloat16(-float(h));  // NOLINT
}

__device__ __forceinline__ bool operator==(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) == float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator!=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) != float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator>(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) > float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator<(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) < float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator>=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) >= float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator<=(
    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return float(lh) <= float(rh);  // NOLINT
}
#endif  // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#endif  // __CUDACC__
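
// Example usage (a minimal sketch; `ScaleReluBF16`, its arguments, and the
// launch shape are illustrative only, not part of this header):
//
//   __global__ void ScaleReluBF16(
//       __nv_bfloat16* data, __nv_bfloat16 alpha, int64_t n) {
//     int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       // Uses the operators above on < SM80 (cuda_bf16.h provides them
//       // natively on >= SM80), plus the max() overload defined above.
//       data[i] *= alpha;
//       data[i] = max(data[i], __nv_bfloat16(0.0f));
//     }
//   }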

#endif  // BF16_ENABLED

#endif  // DGL_ARRAY_CUDA_BF16_CUH_