THCAtomics.cuh 3.69 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
// OP(X, Y): the binary combining operation applied inside every CAS retry
// loop of the atomic helpers in this header; here it is defined as max().
#define OP(X, Y) max((X), (Y))

rusty1s's avatar
rusty1s committed
3
// Primary template for the CAS-emulated atomic OP on integer element types.
// It is deliberately left undefined: behaviour is provided per element size
// (in bytes) by explicit partial specializations, so a size with no
// specialization is rejected at compile time as an incomplete type.
template <typename Scalar, size_t Bytes>
struct AtomicIntegerImpl;
rusty1s's avatar
rusty1s committed
5
6

// Atomic OP on a 1-byte integer (uint8_t / int8_t), emulated with a 32-bit
// atomicCAS on the 4-byte-aligned word that contains the byte. The byte is
// extracted, combined with `val` via OP, re-inserted, and the CAS is retried
// until no other thread modified the word in between.
template<typename T>
struct AtomicIntegerImpl<T, 1> {
  inline __device__ void operator()(T *address, T val) {
    // Byte offset of *address inside its aligned 32-bit word.
    size_t offset = (size_t) address & 3;
    uint32_t *address_as_ui = (uint32_t *) ((char *) address - offset);
    uint32_t shift = (uint32_t) offset * 8;
    // Unsigned literal: 0xff << 24 on a signed int would overflow (UB).
    uint32_t byte_mask = 0xffu << shift;
    uint32_t old = *address_as_ui;
    uint32_t assumed;

    do {
      assumed = old;
      // T(...) reinterprets the extracted byte with T's signedness.
      T res = OP(val, T((old >> shift) & 0xffu));
      // Truncate to one byte before shifting: a negative int8_t converted
      // directly to uint32_t is sign-extended and would overwrite the
      // neighbouring bytes of the word.
      uint32_t res_bits = (uint32_t) (uint8_t) res;
      uint32_t newval = (old & ~byte_mask) | (res_bits << shift);
      old = atomicCAS(address_as_ui, assumed, newval);
    } while (assumed != old);
  }
};

// Atomic OP on a 2-byte integer (e.g. int16_t), emulated with a 32-bit
// atomicCAS on the 4-byte-aligned word that contains it. Assumes `address`
// is 2-byte aligned (natural alignment for a 2-byte T), so the value lies
// entirely in either the lower or upper half of that word.
template<typename T>
struct AtomicIntegerImpl<T, 2> {
  inline __device__ void operator()(T *address, T val) {
    // True when the value occupies the upper 16 bits of its 32-bit word.
    bool in_upper_half = ((size_t) address & 2) != 0;
    uint32_t *address_as_ui = (uint32_t *) ((char *) address - ((size_t) address & 2));
    uint32_t old = *address_as_ui;
    uint32_t assumed;

    do {
      assumed = old;
      T current = in_upper_half ? T(old >> 16) : T(old & 0xffffu);
      T res = OP(val, current);
      // Truncate to 16 bits: a negative int16_t converted straight to
      // uint32_t is sign-extended and would overwrite the other half-word
      // in the lower-half case.
      uint32_t res_bits = (uint32_t) (uint16_t) res;
      uint32_t newval = in_upper_half ? (old & 0xffffu) | (res_bits << 16)
                                      : (old & 0xffff0000u) | res_bits;
      old = atomicCAS(address_as_ui, assumed, newval);
    } while (assumed != old);
  }
};

// Atomic OP on a 4-byte integer, implemented as a 32-bit atomicCAS retry
// loop: OP is recomputed from the latest observed word until the swap
// succeeds without interference from another thread.
template <typename T>
struct AtomicIntegerImpl<T, 4> {
  inline __device__ void operator()(T *address, T val) {
    uint32_t *word = (uint32_t *) address;
    uint32_t observed = *word;
    uint32_t expected;

    for (;;) {
      expected = observed;
      uint32_t desired = OP(val, (T) expected);
      observed = atomicCAS(word, expected, desired);
      if (observed == expected) {
        break;
      }
    }
  }
};

// Atomic OP on an 8-byte integer (e.g. int64_t), implemented as a 64-bit
// atomicCAS retry loop over the value reinterpreted as unsigned long long.
template <typename T>
struct AtomicIntegerImpl<T, 8> {
  inline __device__ void operator()(T *address, T val) {
    unsigned long long *word = (unsigned long long *) address;
    unsigned long long observed = *word;
    unsigned long long expected;

    // Recompute OP from the most recently observed contents until the CAS
    // lands on an unchanged word.
    for (;;) {
      expected = observed;
      unsigned long long desired = OP(val, (T) expected);
      observed = atomicCAS(word, expected, desired);
      if (observed == expected) {
        break;
      }
    }
  }
};

rusty1s's avatar
rusty1s committed
70
71
// Primary template for the CAS-emulated atomic OP on floating-point types.
// Left undefined on purpose: only element sizes with an explicit
// specialization compile; any other size is an incomplete type.
template <typename Scalar, size_t Bytes>
struct AtomicDecimalImpl;
rusty1s's avatar
rusty1s committed
72

rusty1s's avatar
typo  
rusty1s committed
73
74
75
// Atomic OP on a 4-byte float. The storage is treated as int bits so that
// the integer atomicCAS can be used; OP itself is evaluated on the float
// reinterpretation of the currently observed bits.
template <typename T>
struct AtomicDecimalImpl<T, 4> {
  inline __device__ void operator()(T *address, T val) {
    int *bits_ptr = (int *) address;
    int seen = *bits_ptr;
    int expected;

    for (;;) {
      expected = seen;
      int desired = __float_as_int(OP(val, __int_as_float(expected)));
      seen = atomicCAS(bits_ptr, expected, desired);
      // Bitwise comparison of the int view decides whether the swap landed.
      if (seen == expected) {
        break;
      }
    }
  }
};
rusty1s's avatar
rusty1s committed
86

rusty1s's avatar
typo  
rusty1s committed
87
88
89
// Atomic OP on an 8-byte double. The storage is treated as 64-bit integer
// bits for atomicCAS; OP is evaluated on the double reinterpretation.
template <typename T>
struct AtomicDecimalImpl<T, 8> {
  inline __device__ void operator()(T *address, T val) {
    unsigned long long int *bits_ptr = (unsigned long long int *) address;
    unsigned long long int seen = *bits_ptr;
    unsigned long long int expected;

    for (;;) {
      expected = seen;
      unsigned long long int desired =
          __double_as_longlong(OP(val, __longlong_as_double(expected)));
      seen = atomicCAS(bits_ptr, expected, desired);
      if (seen == expected) {
        break;
      }
    }
  }
};
rusty1s's avatar
rusty1s committed
100

rusty1s's avatar
rusty1s committed
101
// atomicMax overloads for the element types this header emulates. Each one
// dispatches to the CAS-based implementation selected by the element's size
// in bytes: integer types use AtomicIntegerImpl, floating-point types use
// AtomicDecimalImpl.
static inline __device__ void atomicMax(uint8_t *address, uint8_t val) {
  AtomicIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}

static inline __device__ void atomicMax(int8_t *address, int8_t val) {
  AtomicIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}

static inline __device__ void atomicMax(int16_t *address, int16_t val) {
  AtomicIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}

static inline __device__ void atomicMax(int64_t *address, int64_t val) {
  AtomicIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}

static inline __device__ void atomicMax(float *address, float val) {
  AtomicDecimalImpl<float, sizeof(float)>()(address, val);
}

static inline __device__ void atomicMax(double *address, double val) {
  AtomicDecimalImpl<double, sizeof(double)>()(address, val);
}
rusty1s's avatar
rusty1s committed
107
#ifdef CUDA_HALF_TENSOR
// atomicMax for half. Previously an empty stub that silently did nothing.
// It is emulated with a 32-bit atomicCAS on the 4-byte-aligned word that
// contains the 2-byte value (assumes `address` is 2-byte aligned). The
// comparison is done in float via OP; a half -> float -> half round trip of
// an exact half value is lossless, so the stored result is one of the inputs.
// NOTE(review): __half_as_ushort / __ushort_as_half require CUDA 9.0+ —
// confirm against the minimum toolkit this project supports.
static inline __device__ void atomicMax(half *address, half val) {
  // True when the value occupies the upper 16 bits of its 32-bit word.
  bool in_upper_half = ((size_t) address & 2) != 0;
  uint32_t *address_as_ui = (uint32_t *) ((char *) address - ((size_t) address & 2));
  uint32_t old = *address_as_ui;
  uint32_t assumed;

  do {
    assumed = old;
    unsigned short cur_bits = in_upper_half ? (unsigned short) (old >> 16)
                                            : (unsigned short) (old & 0xffffu);
    float res = OP(__half2float(val), __half2float(__ushort_as_half(cur_bits)));
    uint32_t res_bits = (uint32_t) __half_as_ushort(__float2half(res));
    uint32_t newval = in_upper_half ? (old & 0xffffu) | (res_bits << 16)
                                    : (old & 0xffff0000u) | res_bits;
    old = atomicCAS(address_as_ui, assumed, newval);
  } while (assumed != old);
}
#endif