Commit ede5e330 authored by rusty1s's avatar rusty1s
Browse files

more generic

parent aaa5a410
......@@ -8,7 +8,7 @@ struct AtomicIntegerImpl<T, 1> {
inline __device__ void operator()(T *address, T val) {
uint32_t * address_as_ui = (uint32_t *) (address - ((size_t) address & 3));
uint32_t old = *address_as_ui;
uint32_t shift = (((size_t) address & 3) * 8);
uint32_t shift = ((size_t) address & 3) * 8;
uint32_t res;
uint32_t assumed;
......@@ -42,13 +42,13 @@ struct AtomicIntegerImpl<T, 2> {
template<typename T>
struct AtomicIntegerImpl<T, 4> {
inline __device__ void operator()(T *address, T val) {
uint32_t *address_as_ull = (uint32_t *) (address);
uint32_t old = *address_as_ull;
uint32_t *address_as_ui = (uint32_t *) address;
uint32_t old = *address_as_ui;
uint32_t assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, OP(val, (T) old));
old = atomicCAS(address_as_ui, assumed, OP(val, (T) old));
} while (assumed != old);
}
};
......@@ -56,7 +56,7 @@ struct AtomicIntegerImpl<T, 4> {
template<typename T>
struct AtomicIntegerImpl<T, 8> {
inline __device__ void operator()(T *address, T val) {
unsigned long long *address_as_ull = (unsigned long long *) (address);
unsigned long long *address_as_ull = (unsigned long long *) address;
unsigned long long old = *address_as_ull;
unsigned long long assumed;
......@@ -67,27 +67,12 @@ struct AtomicIntegerImpl<T, 8> {
}
};
static inline __device__ void atomicMax(uint8_t *address, uint8_t val) {
AtomicIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomicMax(int8_t *address, int8_t val) {
AtomicIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomicMax(int16_t *address, int16_t val) {
AtomicIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomicMax(int64_t *address, int64_t val) {
AtomicIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
#ifdef CUDA_HALF_TENSOR
static inline __device__ void atomicMax(half *address, half val) {}
#endif
template <typename T, size_t n>
struct AtomicDecimalImpl;
static inline __device__ void atomicMax(float *address, float val) {
template <>
struct AtomicDecimalImpl<float, 4> {
inline __device__ void operator()(float *address, float val) {
int *address_as_i = (int *) address;
int old = *address_as_i;
int assumed;
......@@ -96,9 +81,12 @@ static inline __device__ void atomicMax(float *address, float val) {
assumed = old;
old = atomicCAS(address_as_i, assumed, __float_as_int(OP(val, __int_as_float(assumed))));
} while (assumed != old);
}
}
};
static inline __device__ void atomicMax(double *address, double val) {
template <>
struct AtomicDecimalImpl<double, 8> {
inline __device__ void operator()(double *address, double val) {
unsigned long long int *address_as_ull = (unsigned long long int *) address;
unsigned long long int old = *address_as_ull;
unsigned long long int assumed;
......@@ -107,4 +95,15 @@ static inline __device__ void atomicMax(double *address, double val) {
assumed = old;
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(OP(val, __longlong_as_double(assumed))));
} while (assumed != old);
}
}
};
static inline __device__ void atomicMax(uint8_t *address, uint8_t val) { AtomicIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
static inline __device__ void atomicMax(int8_t *address, int8_t val) { AtomicIntegerImpl<int8_t, sizeof(int8_t) >()(address, val); }
static inline __device__ void atomicMax(int16_t *address, int16_t val) { AtomicIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
static inline __device__ void atomicMax(int64_t *address, int64_t val) { AtomicIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
static inline __device__ void atomicMax(float *address, float val) { AtomicDecimalImpl<float, sizeof(float) >()(address, val); }
static inline __device__ void atomicMax(double *address, double val) { AtomicDecimalImpl<double, sizeof(double) >()(address, val); }
#ifdef CUDA_HALF_TENSOR
static inline __device__ void atomicMax(half *address, half val) {}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment