// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2022 by Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file array/cuda/bf16.cuh
 * @brief bfloat16 related functions.
 */
#ifndef DGL_ARRAY_CUDA_BF16_CUH_
#define DGL_ARRAY_CUDA_BF16_CUH_
#include <hip/hip_runtime.h>
#if BF16_ENABLED
#include <hip/hip_bf16.h>

#include <algorithm>

static __device__ __forceinline__ __hip_bfloat16
max(__hip_bfloat16 a, __hip_bfloat16 b) {
#if defined(__HIP_DEVICE_COMPILE__)
  return __hmax(a, b);
#else
  return __hip_bfloat16(max(float(a), float(b)));  // NOLINT
#endif
}

static __device__ __forceinline__ __hip_bfloat16
min(__hip_bfloat16 a, __hip_bfloat16 b) {
#if defined(__HIP_DEVICE_COMPILE__)
  return __hmin(a, b);
#else
  return __hip_bfloat16(min(float(a), float(b)));  // NOLINT
#endif
}
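
// Illustrative sketch (not part of the original header): the max()/min()
// overloads above let type-generic device code treat __hip_bfloat16 like
// float. The helper name Clamp below is hypothetical.
//
//   template <typename T>
//   __device__ __forceinline__ T Clamp(T v, T lo, T hi) {
//     // Resolves to the __hip_bfloat16 overloads when T = __hip_bfloat16.
//     return min(max(v, lo), hi);
//   }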

#if HIP_VERSION_MAJOR < 6
// Arithmetic BF16 operators are already provided by <hip/hip_bf16.h> on
// HIP/ROCm 6.0 and newer; the fallbacks below are only needed for older
// HIP versions.
// #if defined(__DTK_ARCH__) && (__DTK_ARCH__ < 800)
// // CUDA 12.2 adds "emulated" support for older architectures.
// #if defined(DTKRT_VERSION) && (DTKRT_VERSION < 12020)
__device__ __forceinline__ __hip_bfloat16
operator+(const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return __hip_bfloat16(float(lh) + float(rh));  // NOLINT
}
__device__ __forceinline__ __hip_bfloat16
operator-(const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return __hip_bfloat16(float(lh) - float(rh));  // NOLINT
}
__device__ __forceinline__ __hip_bfloat16
operator*(const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return __hip_bfloat16(float(lh) * float(rh));  // NOLINT
}
__device__ __forceinline__ __hip_bfloat16
operator/(const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return __hip_bfloat16(float(lh) / float(rh));  // NOLINT
}

__device__ __forceinline__ __hip_bfloat16& operator+=(
    __hip_bfloat16& lh, const __hip_bfloat16& rh) {  // NOLINT
  lh = __hip_bfloat16(float(lh) + float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __hip_bfloat16& operator-=(
    __hip_bfloat16& lh, const __hip_bfloat16& rh) {  // NOLINT
  lh = __hip_bfloat16(float(lh) - float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __hip_bfloat16& operator*=(
    __hip_bfloat16& lh, const __hip_bfloat16& rh) {  // NOLINT
  lh = __hip_bfloat16(float(lh) * float(rh));       // NOLINT
  return lh;
}
__device__ __forceinline__ __hip_bfloat16& operator/=(
    __hip_bfloat16& lh, const __hip_bfloat16& rh) {  // NOLINT
  lh = __hip_bfloat16(float(lh) / float(rh));       // NOLINT
  return lh;
}

__device__ __forceinline__ __hip_bfloat16& operator++(
    __hip_bfloat16& h) {                // NOLINT
  h = __hip_bfloat16(float(h) + 1.0f);  // NOLINT
  return h;
}
__device__ __forceinline__ __hip_bfloat16& operator--(
    __hip_bfloat16& h) {                // NOLINT
  h = __hip_bfloat16(float(h) - 1.0f);  // NOLINT
  return h;
}
__device__ __forceinline__ __hip_bfloat16
operator++(__hip_bfloat16& h, int) {  // NOLINT
  __hip_bfloat16 ret = h;
  h = __hip_bfloat16(float(h) + 1.0f);  // NOLINT
  return ret;
}
__device__ __forceinline__ __hip_bfloat16
operator--(__hip_bfloat16& h, int) {  // NOLINT
  __hip_bfloat16 ret = h;
  h = __hip_bfloat16(float(h) - 1.0f);  // NOLINT
  return ret;
}

__device__ __forceinline__ __hip_bfloat16 operator+(const __hip_bfloat16& h) {
  return h;
}
__device__ __forceinline__ __hip_bfloat16 operator-(const __hip_bfloat16& h) {
  return __hip_bfloat16(-float(h));  // NOLINT
}

__device__ __forceinline__ bool operator==(
    const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return float(lh) == float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator!=(
    const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return float(lh) != float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator>(
    const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return float(lh) > float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator<(
    const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return float(lh) < float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator>=(
    const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return float(lh) >= float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator<=(
    const __hip_bfloat16& lh, const __hip_bfloat16& rh) {
  return float(lh) <= float(rh);  // NOLINT
}
// #endif  // defined(DTKRT_VERSION) && (DTKRT_VERSION < 12020)
// #endif  // defined(__DTK_ARCH__) && (__DTK_ARCH__ < 800)
#endif  // HIP_VERSION_MAJOR < 6
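
// Illustrative sketch (not part of the original header): with the fallback
// operators above (or the built-in ones on newer HIP), elementwise kernels
// can compute on __hip_bfloat16 directly. The kernel name below is
// hypothetical.
//
//   __global__ void Bf16AxpyExample(
//       const __hip_bfloat16* x, __hip_bfloat16* y, __hip_bfloat16 alpha,
//       int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) y[i] += alpha * x[i];  // operator* and operator+= from above
//   }
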
#if __HIPCC__
// Warp shuffle for bf16: move the 16-bit payload through an unsigned short
// shuffle and reinterpret the result as __hip_bfloat16.
__device__ inline __hip_bfloat16 __shfl_down(
    __hip_bfloat16 var, unsigned int lane_delta, int width = warpSize) {
  union { unsigned short s; __hip_bfloat16 us; } tmp;
  tmp.us = var;
  tmp.s = __shfl_down(tmp.s, lane_delta, width);
  return tmp.us;
}
#endif  // __HIPCC__
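
// Illustrative sketch (not part of the original header): the __shfl_down
// overload above lets warp-level reductions over __hip_bfloat16 mirror their
// float counterparts. The helper name below is hypothetical and assumes a
// fully active warp.
//
//   __device__ __forceinline__ __hip_bfloat16 WarpSumExample(__hip_bfloat16 v) {
//     for (int offset = warpSize / 2; offset > 0; offset /= 2)
//       v += __shfl_down(v, static_cast<unsigned int>(offset));
//     return v;
//   }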

#endif  // BF16_ENABLED

#endif  // DGL_ARRAY_CUDA_BF16_CUH_