THCAtomics.cuh 2.9 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
// Primary template for the integer atomic-max functors. Only declared here;
// behaviour is supplied by partial specialisations on `n`, which the
// atomicMax overloads in this file instantiate with sizeof(T).
template <typename T, size_t n>
struct AtomicMaxIntegerImpl;

// Atomic max for 1-byte integer types.
//
// The update is emulated with a 32-bit atomicCAS on the 4-byte-aligned word
// that contains the target byte: read the word, splice the new byte in, and
// retry until no other thread modified the word in between.
//
// address: location of the byte to update.
// val:     value to combine with *address via max().
template<typename T>
struct AtomicMaxIntegerImpl<T, 1> {
  inline __device__ void operator()(T *address, T val) {
    // sizeof(T) == 1, so subtracting the low two address bits steps back to
    // the start of the enclosing aligned 32-bit word.
    uint32_t *address_as_ui = (uint32_t *) (address - ((size_t) address & 3));
    // Bit position of the target byte inside that word (byte k -> bits 8k).
    uint32_t shift = (((size_t) address & 3) * 8);
    // Unsigned literal: shifting a signed 0xff left by 24 would overflow int.
    uint32_t mask = 0xffu << shift;
    uint32_t old = *address_as_ui;
    uint32_t assumed;

    do {
      assumed = old;
      T current = T((assumed >> shift) & 0xff);
      // Keep only the low 8 bits of the result. For a negative signed T the
      // conversion to uint32_t sign-extends, and without the mask the extra
      // 1-bits would be OR-ed into the neighbouring bytes of the word.
      uint32_t newbyte = uint32_t(max(val, current)) & 0xffu;
      uint32_t newval = (assumed & ~mask) | (newbyte << shift);
      old = atomicCAS(address_as_ui, assumed, newval);
    } while (assumed != old);
  }
};

// Atomic max for 2-byte integer types.
//
// The update is emulated with a 32-bit atomicCAS on the 4-byte-aligned word
// that contains the target half-word: read the word, splice the new
// half-word in, and retry until no other thread modified the word in between.
//
// address: location of the 16-bit value to update.
// val:     value to combine with *address via max().
template<typename T>
struct AtomicMaxIntegerImpl<T, 2> {
  inline __device__ void operator()(T *address, T val) {
    // Bit 1 of the address selects the upper (set) or lower (clear) half of
    // the enclosing 32-bit word.
    bool upper = ((size_t) address & 2) != 0;
    uint32_t *address_as_ui = (uint32_t *) ((char *) address - ((size_t) address & 2));
    uint32_t old = *address_as_ui;
    uint32_t assumed;

    do {
      assumed = old;
      T current = upper ? T(assumed >> 16) : T(assumed & 0xffff);
      // Keep only the low 16 bits of the result. For a negative signed T the
      // conversion to uint32_t sign-extends, which would otherwise overwrite
      // the upper half-word when the target is the lower one.
      uint32_t newhalf = uint32_t(max(val, current)) & 0xffffu;
      uint32_t newval = upper ? (assumed & 0xffffu) | (newhalf << 16)
                              : (assumed & 0xffff0000u) | newhalf;
      old = atomicCAS(address_as_ui, assumed, newval);
    } while (assumed != old);
  }
};

// Atomic max for 8-byte integer types, built on the 64-bit atomicCAS.
//
// address: location of the value to update (reinterpreted as a 64-bit word).
// val:     value to combine with *address via max().
template<typename T>
struct AtomicMaxIntegerImpl<T, 8> {
  inline __device__ void operator()(T *address, T val) {
    unsigned long long *word = (unsigned long long *) (address);
    unsigned long long observed = *word;

    for (;;) {
      // Snapshot the word we believe is stored, then try to swap in the max.
      const unsigned long long expected = observed;
      observed = atomicCAS(word, expected, max(val, (T) expected));
      // atomicCAS returns the word it found; a match means our swap landed.
      if (observed == expected) {
        break;
      }
    }
  }
};

rusty1s's avatar
rusty1s committed
54
55
56
// atomicMax overload for uint8_t: forwards to the functor specialisation
// selected by sizeof(uint8_t).
static inline __device__ void atomicMax(uint8_t *address, uint8_t val) {
  AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)> impl;
  impl(address, val);
}
rusty1s's avatar
rusty1s committed
57

rusty1s's avatar
rusty1s committed
58
59
60
// atomicMax overload for int8_t: forwards to the functor specialisation
// selected by sizeof(int8_t).
static inline __device__ void atomicMax(int8_t *address, int8_t val) {
  AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)> impl;
  impl(address, val);
}
rusty1s's avatar
rusty1s committed
61

rusty1s's avatar
rusty1s committed
62
63
64
// atomicMax overload for int16_t: forwards to the functor specialisation
// selected by sizeof(int16_t).
static inline __device__ void atomicMax(int16_t *address, int16_t val) {
  AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)> impl;
  impl(address, val);
}
rusty1s's avatar
rusty1s committed
65

rusty1s's avatar
typos  
rusty1s committed
66
67
68
// atomicMax overload for int64_t: forwards to the functor specialisation
// selected by sizeof(int64_t).
static inline __device__ void atomicMax(int64_t *address, int64_t val) {
  AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)> impl;
  impl(address, val);
}
rusty1s's avatar
rusty1s committed
69
70

#ifdef CUDA_HALF_TENSOR
rusty1s's avatar
typos  
rusty1s committed
71
// atomicMax overload for half, compiled only when CUDA_HALF_TENSOR is defined.
// NOTE(review): the body is empty, so this overload performs no update and
// *address is left unchanged. Presumably a placeholder so that code calling
// atomicMax generically over scalar types still compiles for half - TODO
// confirm no caller depends on a real half-precision max.
static inline __device__ void atomicMax(half *address, half val) {}
rusty1s's avatar
rusty1s committed
72
73
74
75
76
77
78
79
80
81
82
83
84
#endif

// atomicMax overload for float, implemented as a compare-and-swap loop on the
// value's 32-bit pattern.
//
// address: location of the float to update.
// val:     value to combine with *address via max().
static inline __device__ void atomicMax(float *address, float val) {
  int *bits = (int *) address;
  int seen = *bits;

  for (;;) {
    // The retry test compares integer bit patterns, not float values.
    const int expected = seen;
    const float larger = max(val, __int_as_float(expected));
    seen = atomicCAS(bits, expected, __float_as_int(larger));
    if (seen == expected) {
      break;
    }
  }
}

rusty1s's avatar
typos  
rusty1s committed
85
// atomicMax overload for double, implemented as a compare-and-swap loop on
// the value's 64-bit pattern.
//
// address: location of the double to update.
// val:     value to combine with *address via max().
static inline __device__  void atomicMax(double *address, double val) {
  unsigned long long int *bits = (unsigned long long int *) address;
  unsigned long long int seen = *bits;

  for (;;) {
    // The retry test compares integer bit patterns, not double values.
    const unsigned long long int expected = seen;
    const double larger = max(val, __longlong_as_double(expected));
    seen = atomicCAS(bits, expected, __double_as_longlong(larger));
    if (seen == expected) {
      break;
    }
  }
}