atomic.cuh 4.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cuda/atomic.cuh
 * \brief Atomic functions
 */
#ifndef DGL_ARRAY_CUDA_ATOMIC_H_
#define DGL_ARRAY_CUDA_ATOMIC_H_

#include <cuda_runtime.h>
10
11
12
#include <cassert>
#include "fp16.cuh"

13
14
15
16
17
18
19
20

namespace dgl {
namespace aten {
namespace cuda {

// Maps an operand width in bytes to the unsigned integer type that the
// atomicCAS-based helpers in this file operate on. Only 2-, 4- and 8-byte
// widths are specialised; any other width yields an empty struct with no
// `Type` member, so using it fails at compile time.
template <int Bytes> struct Code { };

template <> struct Code<2> { typedef unsigned short int Type; };

template <> struct Code<4> { typedef unsigned int Type; };

template <> struct Code<8> { typedef unsigned long long int Type; };

// Default conversion between T and the same-width unsigned integer code that
// atomicCAS works on. For integral T a static_cast round-trips the value;
// floating-point types need dedicated specialisations that reinterpret bits.
template <typename T> struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;

  // T -> integer code.
  static __device__ __forceinline__ Type Encode(T value) {
    const Type code = static_cast<Type>(value);
    return code;
  }

  // Integer code -> T.
  static __device__ __forceinline__ T Decode(Type code) {
    const T value = static_cast<T>(code);
    return value;
  }
};

#ifdef USE_FP16
// half <-> 16-bit code: reinterprets the raw bits of the half value
// (no numeric conversion takes place).
template <> struct Cast<half> {
  typedef Code<sizeof(half)>::Type Type;

  static __device__ __forceinline__ Type Encode(half val) {
    const Type bits = __half_as_ushort(val);
    return bits;
  }

  static __device__ __forceinline__ half Decode(Type code) {
    const half value = __ushort_as_half(code);
    return value;
  }
};
#endif  // USE_FP16

// float <-> 32-bit code: reinterprets the raw bits of the float.
template <> struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;

  static __device__ __forceinline__ Type Encode(float val) {
    const Type bits = __float_as_uint(val);
    return bits;
  }

  static __device__ __forceinline__ float Decode(Type code) {
    const float value = __uint_as_float(code);
    return value;
  }
};

// double <-> 64-bit code: reinterprets the raw bits of the double. The bit
// intrinsics traffic in signed long long, so the conversions to and from the
// unsigned code type are spelled out explicitly.
template <> struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;

  static __device__ __forceinline__ Type Encode(double val) {
    return static_cast<Type>(__double_as_longlong(val));
  }

  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(static_cast<long long int>(code));
  }
};

// 16-bit atomic compare-and-swap: stores `val` at `address` iff the value
// currently there equals `compare`, and returns the value that was found.
//
// When compiling for sm_70+ with CUDART > 10.0, the native 16-bit atomicCAS
// is used. Otherwise the operation is emulated with a 32-bit atomicCAS on the
// 4-byte-aligned word containing `address`. Only the 16 bits belonging to
// `address` are rewritten; the neighbouring 16 bits are written back
// unchanged.
static __device__ __forceinline__ unsigned short int atomicCASshort(
    unsigned short int *address,
    unsigned short int compare,
    unsigned short int val) {
#if (defined(CUDART_VERSION) && (CUDART_VERSION > 10000)) && \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
  return atomicCAS(address, compare, val);
#else
  const size_t byte_addr = reinterpret_cast<size_t>(address);
  unsigned int* word =
      reinterpret_cast<unsigned int*>(byte_addr & ~static_cast<size_t>(3));
  // Little-endian: a 16-bit value at byte offset 2 occupies the upper half.
  const unsigned int shift = (byte_addr & 2) ? 16u : 0u;
  const unsigned int mask = 0xffffu << shift;
  // Volatile read so the initial snapshot is loaded from memory rather than
  // a possibly stale cached value; the early-return below relies on it.
  unsigned int old = *reinterpret_cast<volatile unsigned int*>(word);
  unsigned int assumed;
  do {
    assumed = old;
    const unsigned short int current =
        static_cast<unsigned short int>((assumed & mask) >> shift);
    if (current != compare) {
      return current;  // Comparison failed: report what is there, store nothing.
    }
    old = atomicCAS(word, assumed,
                    (assumed & ~mask) |
                        (static_cast<unsigned int>(val) << shift));
    // Retry if either half of the word changed since `assumed` was taken.
  } while (assumed != old);
  return compare;
#endif
}

// Defines `template <typename T> T Atomic<NAME>(T* addr, T val)`.
// It atomically replaces *addr with OP(val, *addr) and returns the value that
// was stored before the update. OP must be #defined where this macro is
// expanded. The update is a compare-and-swap retry loop on T's same-width
// integer code (via Cast), so T must have a matching Code<sizeof(T)>.
#define DEFINE_ATOMIC(NAME)                                                  \
  template <typename T>                                                      \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) {                \
    typedef typename Cast<T>::Type CT;                                       \
    CT* const word = reinterpret_cast<CT*>(addr);                            \
    CT observed = *word;                                                     \
    CT expected;                                                             \
    do {                                                                     \
      /* Retry until nobody changed *addr between our read and the CAS. */  \
      expected = observed;                                                   \
      const T updated = OP(val, Cast<T>::Decode(expected));                  \
      observed = atomicCAS(word, expected, Cast<T>::Encode(updated));        \
    } while (observed != expected);                                          \
    return Cast<T>::Decode(observed);                                        \
  }

// Defines the explicit `half` specialisation of Atomic<NAME> (declared by a
// preceding DEFINE_ATOMIC(NAME)). It uses the same retry loop as
// DEFINE_ATOMIC but issues the CAS through atomicCASshort on the 16-bit code.
// It returns the value stored before the update. OP must be #defined where
// this macro is expanded.
#define DEFINE_ATOMIC_HALF(NAME)                                             \
  template <>                                                                \
  __device__ __forceinline__ half Atomic##NAME<half>(half* addr, half val) { \
    typedef Cast<half>::Type CT;                                             \
    CT* const word = reinterpret_cast<CT*>(addr);                            \
    CT observed = *word;                                                     \
    CT expected;                                                             \
    do {                                                                     \
      expected = observed;                                                   \
      const half updated = OP(val, Cast<half>::Decode(expected));            \
      observed = atomicCASshort(word, expected, Cast<half>::Encode(updated)); \
    } while (observed != expected);                                          \
    return Cast<half>::Decode(observed);                                     \
  }

// Instantiate AtomicMax / AtomicMin / AtomicAdd (plus half specialisations
// when USE_FP16 is set). Each DEFINE_ATOMIC* expansion uses whichever OP is
// defined at that point, so OP is (re)defined around each group. The OP
// bodies are fully parenthesised so an expansion always behaves as a single
// expression, whatever surrounds it in the macro body.
#define OP(a, b) max((a), (b))
DEFINE_ATOMIC(Max)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Max)
#endif  // USE_FP16
#undef OP

#define OP(a, b) min((a), (b))
DEFINE_ATOMIC(Min)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Min)
#endif  // USE_FP16
#undef OP

#define OP(a, b) ((a) + (b))
DEFINE_ATOMIC(Add)
#undef OP

// Atomically adds `val` to `*addr` and returns the value stored before the
// update. Device code for sm_20+ uses the native float atomicAdd. When
// __CUDA_ARCH__ is undefined or below 200, an atomicCAS loop is compiled
// instead. That path still performs the update atomically and returns the old
// value, matching the native path (the previous fallback did neither).
template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 200
  return atomicAdd(addr, val);
#else
  unsigned int* addr_as_ui = reinterpret_cast<unsigned int*>(addr);
  unsigned int old = *addr_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ui, assumed,
                    __float_as_uint(val + __uint_as_float(assumed)));
    // Compare bit patterns, so NaN values do not cause an endless loop.
  } while (assumed != old);
  return __uint_as_float(old);
#endif  // __CUDA_ARCH__
}

// Atomically adds `val` to `*addr` and returns the value stored before the
// update. The native double atomicAdd requires sm_60+. Older architectures
// (e.g. Kepler/Maxwell) use the atomicCAS loop from the CUDA C++ Programming
// Guide, so the addition is actually stored instead of being dropped.
template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
  return atomicAdd(addr, val);
#else
  unsigned long long int* addr_as_ull =
      reinterpret_cast<unsigned long long int*>(addr);
  unsigned long long int old = *addr_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
    // Compare bit patterns, so NaN values do not cause an endless loop.
  } while (assumed != old);
  return __longlong_as_double(old);
#endif  // __CUDA_ARCH__
}

#ifdef USE_FP16
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
// Atomically adds `val` to `*addr` and returns the value stored before the
// update. sm_70+ uses the native half atomicAdd. Older architectures emulate
// it with a 32-bit atomicCAS on the 4-byte-aligned word containing `addr`.
// Only the 16 bits belonging to `*addr` are rewritten; the neighbouring half
// is written back unchanged.
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  return atomicAdd(addr, val);
#else
  const size_t byte_addr = reinterpret_cast<size_t>(addr);
  unsigned int* word =
      reinterpret_cast<unsigned int*>(byte_addr & ~static_cast<size_t>(3));
  // Little-endian: a half at byte offset 2 occupies the upper 16 bits.
  const unsigned int shift = (byte_addr & 2) ? 16u : 0u;
  const unsigned int mask = 0xffffu << shift;
  unsigned int old = *word;
  unsigned int assumed;
  do {
    assumed = old;
    const half current = __ushort_as_half(
        static_cast<unsigned short int>((assumed & mask) >> shift));
    // Sum in float, then round once to half. Float has more than twice
    // half's significand precision, so this equals a direct half add.
    const unsigned short int updated = __half_as_ushort(
        __float2half(__half2float(current) + __half2float(val)));
    old = atomicCAS(word, assumed,
                    (assumed & ~mask) |
                        (static_cast<unsigned int>(updated) << shift));
    // Retry if either half of the word changed since `assumed` was taken.
  } while (assumed != old);
  return __ushort_as_half(
      static_cast<unsigned short int>((old & mask) >> shift));
#endif  // __CUDA_ARCH__
}
#endif  // defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#endif  // USE_FP16


}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ATOMIC_H_