Commit cc561ac4 authored by rusty1s's avatar rusty1s
Browse files

atomic max

parent 88b7f56b
...@@ -4,16 +4,15 @@ struct AtomicMaxIntegerImpl; ...@@ -4,16 +4,15 @@ struct AtomicMaxIntegerImpl;
template<typename T> template<typename T>
struct AtomicMaxIntegerImpl<T, 1> { struct AtomicMaxIntegerImpl<T, 1> {
inline __device__ void operator()(T *address, T val) { inline __device__ void operator()(T *address, T val) {
uint32_t * address_as_ui = uint32_t * address_as_ui = (uint32_t *) (address - ((size_t) address & 3));
(uint32_t *) (address - ((size_t)address & 3));
uint32_t old = *address_as_ui; uint32_t old = *address_as_ui;
uint32_t shift = (((size_t)address & 3) * 8); uint32_t shift = (((size_t) address & 3) * 8);
uint32_t sum; uint32_t sum;
uint32_t assumed; uint32_t assumed;
do { do {
assumed = old; assumed = old;
sum = val + T((old >> shift) & 0xff); sum = max(val, T((old >> shift) & 0xff));
old = (old & ~(0x000000ff << shift)) | (sum << shift); old = (old & ~(0x000000ff << shift)) | (sum << shift);
old = atomicCAS(address_as_ui, assumed, old); old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old); } while (assumed != old);
...@@ -23,8 +22,7 @@ struct AtomicMaxIntegerImpl<T, 1> { ...@@ -23,8 +22,7 @@ struct AtomicMaxIntegerImpl<T, 1> {
template<typename T> template<typename T>
struct AtomicMaxIntegerImpl<T, 2> { struct AtomicMaxIntegerImpl<T, 2> {
inline __device__ void operator()(T *address, T val) { inline __device__ void operator()(T *address, T val) {
uint32_t * address_as_ui = uint32_t * address_as_ui = (uint32_t *) ((char *) address - ((size_t) address & 2));
(uint32_t *) ((char *)address - ((size_t)address & 2));
uint32_t old = *address_as_ui; uint32_t old = *address_as_ui;
uint32_t sum; uint32_t sum;
uint32_t newval; uint32_t newval;
...@@ -32,7 +30,7 @@ struct AtomicMaxIntegerImpl<T, 2> { ...@@ -32,7 +30,7 @@ struct AtomicMaxIntegerImpl<T, 2> {
do { do {
assumed = old; assumed = old;
sum = val + (size_t)address & 2 ? T(old >> 16) : T(old & 0xffff); sum = max(val, (size_t)address & 2 ? T(old >> 16) : T(old & 0xffff));
newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) : (old & 0xffff0000) | sum; newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) : (old & 0xffff0000) | sum;
old = atomicCAS(address_as_ui, assumed, newval); old = atomicCAS(address_as_ui, assumed, newval);
} while (assumed != old); } while (assumed != old);
...@@ -40,43 +38,30 @@ struct AtomicMaxIntegerImpl<T, 2> { ...@@ -40,43 +38,30 @@ struct AtomicMaxIntegerImpl<T, 2> {
}; };
template<typename T> template<typename T>
struct AtomicMaxIntegerImpl<T, 4> { struct AtomicMaxIntegerImpl<T, 8> {
inline __device__ void operator()(T *address, T val) { inline __device__ void operator()(T *address, T val) {
uint32_t * address_as_ui = (uint32_t *) (address); unsigned long long *address_as_ull = (unsigned long long *) (address);
uint32_t old = *address_as_ui; unsigned long long old = *address_as_ull;
uint32_t newval; unsigned long long assumed;
uint32_t assumed;
do { do {
assumed = old; assumed = old;
newval = val + (T)old; old = atomicCAS(address_as_ull, assumed, max(val, (T) old));
old = atomicCAS(address_as_ui, assumed, newval);
} while (assumed != old); } while (assumed != old);
} }
}; };
template<typename T> static inline __device__ void atomicMax(uint8_t *address, uint8_t val) {
struct AtomicMaxIntegerImpl<T, 8> { AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
inline __device__ void operator()(T *address, T val) { }
int *address_as_ull = (int*) (address);
int newval = *address_as_ull;
atomicMax(address_as_ull, newval);
/* unsigned long long newval; */
/* unsigned long long assumed; */
/* do { */
/* assumed = old; */
/* newval = val + (T)old; */
/* old = atomicCAS(address_as_ui, assumed, newval); */
/* } while (assumed != old); */
}
};
static inline __device__ void atomicMax(uint8_t *address, uint8_t val) {}
static inline __device__ void atomicMax(int8_t *address, int8_t val) {} static inline __device__ void atomicMax(int8_t *address, int8_t val) {
AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomicMax(int16_t *address, int16_t val) {} static inline __device__ void atomicMax(int16_t *address, int16_t val) {
AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomicMax(int64_t *address, int64_t val) { static inline __device__ void atomicMax(int64_t *address, int64_t val) {
AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment