fp16.cuh 3.68 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/fp16.cuh
 * \brief float16 related functions.
 * \note This file is modified from the TVM project:
 *       https://github.com/apache/tvm/blob/e561007f0c330e3d14c2bc8a3ef40fb741db9004/src/target/source/literal/cuda_half_t.h
 */
8
9
#ifndef DGL_ARRAY_CUDA_FP16_CUH_
#define DGL_ARRAY_CUDA_FP16_CUH_
10
11
12
13


#ifdef USE_FP16
#include <cuda_fp16.h>
14
#include <algorithm>
15

16
/*!
 * \brief Larger of two half-precision values.
 * \param a First operand.
 * \param b Second operand.
 * \return When compiled for __CUDA_ARCH__ >= 530: a if __hgt(a, b) holds,
 *         otherwise b. In every other compilation pass: the float max of the
 *         two operands, converted back to half.
 */
static __device__ __forceinline__ half max(half a, half b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Decide with the half-precision "greater than" intrinsic.
  if (__hgt(__half(a), __half(b))) {
    return a;
  }
  return b;
#else
  // Widen both operands to float, take the float max, then narrow again.
  const float wide_a = static_cast<float>(a);
  const float wide_b = static_cast<float>(b);
  return __half(max(wide_a, wide_b));  // NOLINT
#endif
}

24
/*!
 * \brief Smaller of two half-precision values.
 * \param a First operand.
 * \param b Second operand.
 * \return When compiled for __CUDA_ARCH__ >= 530: a if __hlt(a, b) holds,
 *         otherwise b. In every other compilation pass: the float min of the
 *         two operands, converted back to half.
 */
static __device__ __forceinline__ half min(half a, half b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Decide with the half-precision "less than" intrinsic.
  if (__hlt(__half(a), __half(b))) {
    return a;
  }
  return b;
#else
  // Widen both operands to float, take the float min, then narrow again.
  const float wide_a = static_cast<float>(a);
  const float wide_b = static_cast<float>(b);
  return __half(min(wide_a, wide_b));  // NOLINT
#endif
}
31
32
33
34

#ifdef __CUDACC__
// Arithmetic FP16 operations for architecture >= 5.3 are already defined in cuda_fp16.h
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)
35
36
37
38
39
40
41
42
43
44
45
46
/*!
 * \brief Binary addition for __half, evaluated in float.
 * \return The float sum of both operands converted back to __half.
 */
__device__ __forceinline__ __half operator+(const __half& lh, const __half& rh) {
  const float sum = static_cast<float>(lh) + static_cast<float>(rh);
  return __half(sum);
}
/*!
 * \brief Binary subtraction for __half, evaluated in float.
 * \return The float difference (lh - rh) converted back to __half.
 */
__device__ __forceinline__ __half operator-(const __half& lh, const __half& rh) {
  const float difference = static_cast<float>(lh) - static_cast<float>(rh);
  return __half(difference);
}
/*!
 * \brief Multiplication for __half, evaluated in float.
 * \return The float product of both operands converted back to __half.
 */
__device__ __forceinline__ __half operator*(const __half& lh, const __half& rh) {
  const float product = static_cast<float>(lh) * static_cast<float>(rh);
  return __half(product);
}
/*!
 * \brief Division for __half, evaluated in float.
 * \return The float quotient (lh / rh) converted back to __half.
 */
__device__ __forceinline__ __half operator/(const __half& lh, const __half& rh) {
  const float quotient = static_cast<float>(lh) / static_cast<float>(rh);
  return __half(quotient);
}
47

48
49
50
51
52
53
54
55
56
57
58
59
/*!
 * \brief In-place addition for __half, evaluated in float.
 * \return Reference to lh after it has been overwritten with the sum.
 */
__device__ __forceinline__ __half& operator+=(__half& lh, const __half& rh) {  // NOLINT
  const float sum = static_cast<float>(lh) + static_cast<float>(rh);
  lh = __half(sum);
  return lh;
}
/*!
 * \brief In-place subtraction for __half, evaluated in float.
 * \return Reference to lh after it has been overwritten with lh - rh.
 */
__device__ __forceinline__ __half& operator-=(__half& lh, const __half& rh) {  // NOLINT
  const float difference = static_cast<float>(lh) - static_cast<float>(rh);
  lh = __half(difference);
  return lh;
}
/*!
 * \brief In-place multiplication for __half, evaluated in float.
 * \return Reference to lh after it has been overwritten with the product.
 */
__device__ __forceinline__ __half& operator*=(__half& lh, const __half& rh) {  // NOLINT
  const float product = static_cast<float>(lh) * static_cast<float>(rh);
  lh = __half(product);
  return lh;
}
/*!
 * \brief In-place division for __half, evaluated in float.
 * \return Reference to lh after it has been overwritten with lh / rh.
 */
__device__ __forceinline__ __half& operator/=(__half& lh, const __half& rh) {  // NOLINT
  const float quotient = static_cast<float>(lh) / static_cast<float>(rh);
  lh = __half(quotient);
  return lh;
}
60

61
62
63
64
65
66
67
68
69
70
71
72
/*!
 * \brief Prefix increment for __half: adds 1.0f in float and stores the result.
 * \return Reference to h holding the incremented value.
 */
__device__ __forceinline__ __half& operator++(__half& h) {  // NOLINT
  const float incremented = static_cast<float>(h) + 1.0f;
  h = __half(incremented);
  return h;
}
/*!
 * \brief Prefix decrement for __half: subtracts 1.0f in float and stores the
 *        result.
 * \return Reference to h holding the decremented value.
 */
__device__ __forceinline__ __half& operator--(__half& h) {  // NOLINT
  const float decremented = static_cast<float>(h) - 1.0f;
  h = __half(decremented);
  return h;
}
/*!
 * \brief Postfix increment for __half: adds 1.0f in float and stores the result.
 * \return A copy of the value h held before the increment.
 */
__device__ __forceinline__ __half  operator++(__half& h, int) {  // NOLINT
  // Snapshot the current value first; it is what the postfix form yields.
  __half previous = h;
  const float incremented = static_cast<float>(h) + 1.0f;
  h = __half(incremented);
  return previous;
}
/*!
 * \brief Postfix decrement for __half: subtracts 1.0f in float and stores the
 *        result.
 * \return A copy of the value h held before the decrement.
 */
__device__ __forceinline__ __half  operator--(__half& h, int) {  // NOLINT
  // Snapshot the current value first; it is what the postfix form yields.
  __half previous = h;
  const float decremented = static_cast<float>(h) - 1.0f;
  h = __half(decremented);
  return previous;
}
73
74

/*!
 * \brief Unary plus for __half.
 * \return A copy of the operand, unchanged.
 */
__device__ __forceinline__ __half operator+(const __half& h) {
  return h;
}
75
76
77
/*!
 * \brief Unary negation for __half, evaluated in float.
 * \return The negated float value of h converted back to __half.
 */
__device__ __forceinline__ __half operator-(const __half& h) {
  const float negated = -static_cast<float>(h);
  return __half(negated);
}
78

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/*!
 * \brief Equality comparison for __half, performed on the float values.
 * \return true when the float conversions of lh and rh compare equal.
 */
__device__ __forceinline__ bool operator==(const __half& lh, const __half& rh) {
  const float left = static_cast<float>(lh);
  const float right = static_cast<float>(rh);
  return left == right;
}
/*!
 * \brief Inequality comparison for __half, performed on the float values.
 * \return true when the float conversions of lh and rh compare unequal.
 */
__device__ __forceinline__ bool operator!=(const __half& lh, const __half& rh) {
  const float left = static_cast<float>(lh);
  const float right = static_cast<float>(rh);
  return left != right;
}
/*!
 * \brief Greater-than comparison for __half, performed on the float values.
 * \return true when float(lh) is greater than float(rh).
 */
__device__ __forceinline__ bool operator>(const __half& lh, const __half& rh) {
  const float left = static_cast<float>(lh);
  const float right = static_cast<float>(rh);
  return left > right;
}
/*!
 * \brief Less-than comparison for __half, performed on the float values.
 * \return true when float(lh) is less than float(rh).
 */
__device__ __forceinline__ bool operator<(const __half& lh, const __half& rh) {
  const float left = static_cast<float>(lh);
  const float right = static_cast<float>(rh);
  return left < right;
}
/*!
 * \brief Greater-or-equal comparison for __half, performed on the float values.
 * \return true when float(lh) is greater than or equal to float(rh).
 */
__device__ __forceinline__ bool operator>=(const __half& lh, const __half& rh) {
  const float left = static_cast<float>(lh);
  const float right = static_cast<float>(rh);
  return left >= right;
}
/*!
 * \brief Less-or-equal comparison for __half, performed on the float values.
 * \return true when float(lh) is less than or equal to float(rh).
 */
__device__ __forceinline__ bool operator<=(const __half& lh, const __half& rh) {
  const float left = static_cast<float>(lh);
  const float right = static_cast<float>(rh);
  return left <= right;
}
97
98
99
#endif  // __CUDA_ARCH__ < 530
#endif  // __CUDACC__

100
101
#endif  // USE_FP16

102
#endif  // DGL_ARRAY_CUDA_FP16_CUH_