atomics.cuh 12.5 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
// ATOMIC(NAME) generates CAS-loop fallbacks for an atomic read-modify-write
// operation. Define OP(X, Y) before expanding it: X is the incoming `val`,
// Y is the value currently stored at `address`, and OP's result (converted
// back to `scalar`) is what gets stored.
//
// Generated class templates, used as `Impl<scalar, sizeof(scalar)>()(p, v)`:
//   Atomic<NAME>IntegerImpl<scalar, size>  integral scalars, size 1/2/4/8
//   Atomic<NAME>DecimalImpl<scalar, size>  float (size 4) / double (size 8)
//
// Each variant loops: compute the new word from the last observed word and
// atomicCAS it in; if another thread changed the word in between, retry with
// the value atomicCAS returned.
//
// 1- and 2-byte scalars are updated through the 4-byte-aligned 32-bit word
// that contains them. Their operands are bound to `scalar` locals before OP is
// applied, so an unparenthesized OP cannot re-associate with the lane
// extraction. The result is masked to its own lane before merging, so carries
// or sign-extension bits from OP's int-promoted result cannot clobber the
// neighbouring bytes of the word.
//
// NOTE(review): the variants cast `address` straight to (sub-)word offsets,
// so `address` is presumably required to be aligned to sizeof(scalar).
#define ATOMIC(NAME)                                                           \
  template <typename scalar, size_t size> struct Atomic##NAME##IntegerImpl;    \
                                                                               \
  template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 1> {     \
    inline __device__ void operator()(scalar *address, scalar val) {           \
      uint32_t *address_as_ui =                                                \
          (uint32_t *)((char *)address - ((size_t)address & 3));               \
      uint32_t shift = ((size_t)address & 3) * 8;                              \
      uint32_t lane = 0xffu << shift;                                          \
      uint32_t old = *address_as_ui;                                           \
      uint32_t assumed;                                                        \
                                                                               \
      do {                                                                     \
        assumed = old;                                                         \
        /* Bind operands to locals (OP may be unparenthesized), then mask */   \
        /* the result to 8 bits so it cannot spill into adjacent bytes. */     \
        scalar current = scalar((assumed >> shift) & 0xffu);                   \
        scalar result = OP(val, current);                                      \
        uint32_t bits = ((uint32_t)result & 0xffu) << shift;                   \
        old = atomicCAS(address_as_ui, assumed, (assumed & ~lane) | bits);     \
      } while (assumed != old);                                                \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 2> {     \
    inline __device__ void operator()(scalar *address, scalar val) {           \
      uint32_t *address_as_ui =                                                \
          (uint32_t *)((char *)address - ((size_t)address & 2));               \
      uint32_t shift = ((size_t)address & 2) * 8;                              \
      uint32_t lane = 0xffffu << shift;                                        \
      uint32_t old = *address_as_ui;                                           \
      uint32_t assumed;                                                        \
                                                                               \
      do {                                                                     \
        assumed = old;                                                         \
        /* As above, but on the 16-bit half of the word. */                    \
        scalar current = scalar((assumed >> shift) & 0xffffu);                 \
        scalar result = OP(val, current);                                      \
        uint32_t bits = ((uint32_t)result & 0xffffu) << shift;                 \
        old = atomicCAS(address_as_ui, assumed, (assumed & ~lane) | bits);     \
      } while (assumed != old);                                                \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 4> {     \
    inline __device__ void operator()(scalar *address, scalar val) {           \
      uint32_t *address_as_ui = (uint32_t *)address;                           \
      uint32_t old = *address_as_ui;                                           \
      uint32_t assumed;                                                        \
                                                                               \
      do {                                                                     \
        assumed = old;                                                         \
        old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old));         \
      } while (assumed != old);                                                \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 8> {     \
    inline __device__ void operator()(scalar *address, scalar val) {           \
      unsigned long long *address_as_ull = (unsigned long long *)address;      \
      unsigned long long old = *address_as_ull;                                \
      unsigned long long assumed;                                              \
                                                                               \
      do {                                                                     \
        assumed = old;                                                         \
        old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old));        \
      } while (assumed != old);                                                \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl;    \
                                                                               \
  template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> {     \
    inline __device__ void operator()(scalar *address, scalar val) {           \
      int *address_as_i = (int *)address;                                      \
      int old = *address_as_i;                                                 \
      int assumed;                                                             \
                                                                               \
      do {                                                                     \
        assumed = old;                                                         \
        old = atomicCAS(address_as_i, assumed,                                 \
                        __float_as_int(OP(val, __int_as_float(assumed))));     \
      } while (assumed != old);                                                \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 8> {     \
    inline __device__ void operator()(scalar *address, scalar val) {           \
      unsigned long long int *address_as_ull =                                 \
          (unsigned long long int *)address;                                   \
      unsigned long long int old = *address_as_ull;                            \
      unsigned long long int assumed;                                          \
                                                                               \
      do {                                                                     \
        assumed = old;                                                         \
        old = atomicCAS(                                                       \
            address_as_ull, assumed,                                           \
            __double_as_longlong(OP(val, __longlong_as_double(assumed))));     \
      } while (assumed != old);                                                \
    }                                                                          \
  };

// Atomic addition: *address = *address + val.
// int32 and float (and double on sm_60+) use CUDA's native atomicAdd, int64
// uses the native unsigned 64-bit atomicAdd, and 8-/16-bit integers (and
// double before sm_60) use the CAS loops generated by ATOMIC(Add).
// OP's operands and result are fully parenthesized: the bare `Y + X` form,
// when expanded with a conditional-expression operand such as
// `c ? a : b`, yields `c ? a : b + X` and drops X from the `a` arm.
#define OP(X, Y) ((Y) + (X))
ATOMIC(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
  AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
  AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
  AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
  atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
  // Two's-complement addition produces the same bits for signed and unsigned
  // operands, so the native unsigned 64-bit atomicAdd is exact here and
  // avoids a CAS retry loop.
  static_assert(sizeof(int64_t) == sizeof(unsigned long long),
                "int64_t must be 64 bits wide");
  atomicAdd((unsigned long long *)address, (unsigned long long)val);
}
static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}
// Native atomicAdd(double *) exists only on sm_60+. Every toolkit that can
// target sm_60 is CUDA 8.0 or newer, so no separate CUDA_VERSION test is
// needed. Such a test would also read as 0 inside #if whenever <cuda.h> has
// not been included, silently forcing the CAS fallback.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
static inline __device__ void atomAdd(double *address, double val) {
  AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif

// Atomic multiplication: *address = *address * val.
// Every width uses the CAS loops generated by ATOMIC(Mul); no native atomic
// multiply is called here.
// OP's operands and result are fully parenthesized: the bare `Y *X` form,
// when expanded with a conditional-expression operand such as `c ? a : b`,
// yields `c ? a : b * X` and drops X from the `a` arm.
#define OP(X, Y) ((Y) * (X))
ATOMIC(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
  AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
  AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
  AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
  AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
  AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
  AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
  AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}

// Atomic division: *address = *address / val.
// Every width uses the CAS loops generated by ATOMIC(Div); no native atomic
// divide is called here.
// OP's operands and result are fully parenthesized: the bare `Y / X` form,
// when expanded with a conditional-expression operand such as `c ? a : b`,
// yields `c ? a : b / X` and drops X from the `a` arm.
// NOTE(review): integer division by zero is not guarded here; callers
// presumably never pass val == 0 for integer types.
#define OP(X, Y) ((Y) / (X))
ATOMIC(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
  AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
  AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
  AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
  AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
  AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
  AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
  AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}

// Atomic maximum: stores max(stored value, val) at `address`.
// OP(X, Y) receives X = the incoming value and Y = the value currently stored.
#define OP(X, Y) max((Y), (X))
ATOMIC(Max)
#undef OP

// 32-bit signed integers map onto CUDA's native atomicMax ...
static inline __device__ void atomMax(int32_t *address, int32_t val) {
  atomicMax(address, val);
}

// ... every other type goes through a CAS loop generated by ATOMIC(Max) above.
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
  AtomicMaxIntegerImpl<uint8_t, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
  AtomicMaxIntegerImpl<int8_t, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
  AtomicMaxIntegerImpl<int16_t, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
  AtomicMaxIntegerImpl<int64_t, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
  AtomicMaxDecimalImpl<float, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
  AtomicMaxDecimalImpl<double, sizeof(*address)>{}(address, val);
}

// Atomic minimum: stores min(stored value, val) at `address`.
// OP(X, Y) receives X = the incoming value and Y = the value currently stored.
#define OP(X, Y) min((Y), (X))
ATOMIC(Min)
#undef OP

// 32-bit signed integers map onto CUDA's native atomicMin ...
static inline __device__ void atomMin(int32_t *address, int32_t val) {
  atomicMin(address, val);
}

// ... every other type goes through a CAS loop generated by ATOMIC(Min) above.
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
  AtomicMinIntegerImpl<uint8_t, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
  AtomicMinIntegerImpl<int8_t, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
  AtomicMinIntegerImpl<int16_t, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
  AtomicMinIntegerImpl<int64_t, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
  AtomicMinDecimalImpl<float, sizeof(*address)>{}(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
  AtomicMinDecimalImpl<double, sizeof(*address)>{}(address, val);
}