/*!
 *  Copyright (c) 2019 by Contributors
 * \file kernel/cuda/atomic.cuh
 * \brief Atomic functions
 */
#ifndef DGL_KERNEL_CUDA_ATOMIC_CUH_
#define DGL_KERNEL_CUDA_ATOMIC_CUH_

#include <cuda_runtime.h>
#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
#endif

namespace dgl {
namespace kernel {
namespace cuda {

// Type trait that selects an unsigned integer "code" type of the same byte width.
template <int Bytes> struct Code { };

template <> struct Code<4> {
  typedef unsigned int Type;
};

template <> struct Code<8> {
  typedef unsigned long long int Type;
};

// Helper class for converting to/from atomicCAS-compatible types.
template <typename T> struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
  }
  static __device__ __forceinline__ T Decode(Type code) {
    return static_cast<T>(code);
  }
};

template <> struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
  }
  static __device__ __forceinline__ float Decode(Type code) {
    return __uint_as_float(code);
  }
};

template <> struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
  }
  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(code);
  }
};
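
// Note on the encoding: atomicCAS compares the integer codes bit-for-bit
// rather than using floating-point equality, so the retry loops below
// terminate even for values like NaN, where `x == x` is false as a float
// but the bit pattern still compares equal to itself as an integer.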

// Emulates an atomic read-modify-write for the binary operation OP(val, old)
// by spinning on atomicCAS: re-read the current word, apply OP, and retry
// until no other thread has modified the word in between.
#define DEFINE_ATOMIC(NAME) \
  template <typename T>                                          \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) {    \
    typedef typename Cast<T>::Type CT;                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                \
    CT old = *addr_as_ui;                                        \
    CT assumed;                                                  \
    do {                                                         \
      assumed = old;                                             \
      old = atomicCAS(addr_as_ui, assumed,                       \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));       \
    } while (assumed != old);                                    \
    return Cast<T>::Decode(old);                                 \
  }

#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
#undef OP

#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
#undef OP

#define OP(a, b) ((a) + (b))
DEFINE_ATOMIC(Add)
#undef OP

#if __CUDA_ARCH__ >= 200
template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
  return atomicAdd(addr, val);
}
#endif  // __CUDA_ARCH__

#if __CUDA_ARCH__ >= 600
template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
  return atomicAdd(addr, val);
}
#endif

#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#if __CUDA_ARCH__ >= 600
template <>
__device__ __forceinline__ __half2 AtomicAdd<__half2>(__half2* addr, __half2 val) {
  return atomicAdd(addr, val);
}
#endif  // __CUDA_ARCH__

#if __CUDA_ARCH__ >= 700
template <>
__device__ __forceinline__ __half AtomicAdd<__half>(__half* addr, __half val) {
  return atomicAdd(addr, val);
}
#endif  // __CUDA_ARCH__
#endif  // CUDART_VERSION >= 10000

#define OP(a, b) ((a) * (b))
DEFINE_ATOMIC(Mul)
#undef OP
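
// Illustrative sketch, not part of the original file: a hypothetical kernel
// showing how the generated atomics are typically used, here scattering
// per-edge values into per-node accumulators. `ExampleScatterAddKernel`,
// `score`, `row`, and `out` are names assumed for this example only.
template <typename T>
__global__ void ExampleScatterAddKernel(
    const T* score, const int64_t* row, T* out, int64_t num_edges) {
  const int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  if (tx < num_edges) {
    // Several edges may share a destination node, so a plain store would
    // race; AtomicAdd makes each read-modify-write indivisible.
    AtomicAdd(out + row[tx], score[tx]);
  }
}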

/**
* \brief Performs an atomic compare-and-swap on 64-bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* \param address The address to perform the atomic operation on.
* \param compare The value to compare to.
* \param val The new value to conditionally store.
*
* \return The old value at the address.
*/
inline __device__ int64_t AtomicCAS(
    int64_t * const address,
    const int64_t compare,
    const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(reinterpret_cast<Type*>(address),
                   static_cast<Type>(compare),
                   static_cast<Type>(val));
}

/**
* \brief Performs an atomic compare-and-swap on 32-bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* \param address The address to perform the atomic operation on.
* \param compare The value to compare to.
* \param val The new value to conditionally store.
*
* \return The old value at the address.
*/
inline __device__ int32_t AtomicCAS(
    int32_t * const address,
    const int32_t compare,
    const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(reinterpret_cast<Type*>(address),
                   static_cast<Type>(compare),
                   static_cast<Type>(val));
}
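
// Illustrative sketch, not part of the original file: AtomicCAS lets exactly
// one thread claim an empty slot, as in a simple open-addressing hash-table
// insert. `TryClaimSlot`, `table`, and `kEmpty` are hypothetical names, and
// IdType is expected to be int32_t or int64_t to match the overloads above.
template <typename IdType>
__device__ __forceinline__ bool TryClaimSlot(
    IdType* table, const int64_t pos, const IdType key, const IdType kEmpty) {
  // Only the winning thread observes kEmpty as the returned old value;
  // every other thread sees the winner's key (or another occupant).
  return AtomicCAS(table + pos, kEmpty, key) == kEmpty;
}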

/**
* \brief Performs an atomic maximum on 64-bit integers: atomically stores
* `max(old, val)` at `address` and returns the old value. The operands are
* reinterpreted as unsigned for `::atomicMax`, so the result is only
* meaningful for non-negative values.
*/
inline __device__ int64_t AtomicMax(
    int64_t * const address,
    const int64_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address),
                   static_cast<Type>(val));
}

/**
* \brief Performs an atomic maximum on 32-bit integers: atomically stores
* `max(old, val)` at `address` and returns the old value.
*/
inline __device__ int32_t AtomicMax(
    int32_t * const address,
    const int32_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address),
                   static_cast<Type>(val));
}
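
// Illustrative sketch, not part of the original file: a hypothetical kernel
// reducing an array to its maximum with AtomicMax. `result` must be
// pre-initialized by the caller (e.g. to 0, which also satisfies the
// non-negativity caveat of the 64-bit overload above).
template <typename IdType>
__global__ void ExampleMaxKernel(
    const IdType* values, IdType* result, int64_t num_values) {
  const int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  if (tx < num_values) {
    // Each thread folds one element into the shared running maximum.
    AtomicMax(result, values[tx]);
  }
}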

}  // namespace cuda
}  // namespace kernel
}  // namespace dgl

#endif  // DGL_KERNEL_CUDA_ATOMIC_CUH_