#pragma once

#ifndef __CUDACC_RTC__
#include <cuda_runtime.h>
#endif

#include <cuda/atomic>
#include <cutlass/numeric_types.h>

using cutlass::bfloat16_t;
using cutlass::half_t;

#define TL_DEVICE __forceinline__ __device__

// Map the cutlass wrapper types to the native CUDA types expected by the
// hardware atomic intrinsics; every other type passes through unchanged.
template <typename T> struct normalize_atomic_type {
  using type = T;
};

template <> struct normalize_atomic_type<half_t> {
  using type = half;
};

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
template <> struct normalize_atomic_type<bfloat16_t> {
  using type = __nv_bfloat16;
};
#endif

template <typename T1, typename T2> TL_DEVICE T1 cuda_cast(T2 val) {
  return T1(val);
}

template <> TL_DEVICE half cuda_cast<half, float>(float val) {
  return __float2half(val);
}

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
  return __float2bfloat16(val);
}
#endif

// Note: the half/__nv_bfloat16 fast paths assume atomicMax/atomicMin overloads
// for those types are provided elsewhere; CUDA has no built-in sub-word
// atomicMax/atomicMin.
template <typename T1, typename T2>
TL_DEVICE void AtomicMax(T1 &ref, T2 val,
                         int memory_order = int(cuda::memory_order_relaxed)) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  T1 *address = &ref;
  if constexpr (std::is_same_v<NT1, half> ||
                std::is_same_v<NT1, __nv_bfloat16>) {
    atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
  } else {
    cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
    aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
  }
}

template <typename T1, typename T2>
TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  T1 *address = &ref;
  if constexpr (std::is_same_v<NT1, half> ||
                std::is_same_v<NT1, __nv_bfloat16>) {
    return static_cast<T1>(
        atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
  } else {
    cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
    return static_cast<T1>(
        aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
  }
}

template <typename T1, typename T2>
TL_DEVICE void AtomicMin(T1 &ref, T2 val,
                         int memory_order = int(cuda::memory_order_relaxed)) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  T1 *address = &ref;
  if constexpr (std::is_same_v<NT1, half> ||
                std::is_same_v<NT1, __nv_bfloat16>) {
    atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
  } else {
    cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
    aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
  }
}

template <typename T1, typename T2>
TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  T1 *address = &ref;
  if constexpr (std::is_same_v<NT1, half> ||
                std::is_same_v<NT1, __nv_bfloat16>) {
    return static_cast<T1>(
        atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
  } else {
    cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
    return static_cast<T1>(
        aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
  }
}

template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
                         int memory_order = int(cuda::memory_order_relaxed)) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  T1 *address = &ref;
  if constexpr (std::is_same_v<NT1, half> ||
                std::is_same_v<NT1, __nv_bfloat16>) {
    // half/bf16 only have a relaxed hardware atomicAdd and are too small for
    // cuda::atomic_ref, so the requested memory_order is ignored on this path.
    // (memory_order is a runtime argument, so it cannot appear in if constexpr.)
    atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
  } else {
    cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
    aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
  }
}

template <typename T1, typename T2>
TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  T1 *address = &ref;
  if constexpr (std::is_same_v<NT1, half> ||
                std::is_same_v<NT1, __nv_bfloat16>) {
    // See AtomicAdd: memory_order is ignored on the half/bf16 fast path.
    return static_cast<T1>(
        atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
  } else {
    cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
    return static_cast<T1>(
        aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
  }
}
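// Illustrative usage sketch for the scalar helpers above. The kernel and the
// `partials` / `global_max` buffers are hypothetical, not part of this header;
// it only shows how the default relaxed ordering and an explicit ordering
// would be requested.
//
//   __global__ void reduce_kernel(float *partials, float *global_max,
//                                 const float *x, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       AtomicAdd(partials[0], x[i]);   // relaxed by default
//       AtomicMax(global_max[0], x[i]); // float max via cuda::atomic_ref
//       float previous = AtomicAddRet(partials[1], 1.0f,
//                                     int(cuda::memory_order_acq_rel));
//       // `previous` holds the value stored before the add.
//     }
//   }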
// TODO: add memory_order support for the vectorized atomic adds below.
TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  atomicAdd(reinterpret_cast<half2 *>(ref),
            static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

TL_DEVICE half2
AtomicAddx2Ret(half_t *ref, half_t *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  return atomicAdd(reinterpret_cast<half2 *>(ref),
                   static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  atomicAdd(
      reinterpret_cast<__nv_bfloat162 *>(ref),
      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}

TL_DEVICE __nv_bfloat162
AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  return atomicAdd(
      reinterpret_cast<__nv_bfloat162 *>(ref),
      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}
#endif

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
TL_DEVICE void AtomicAddx2(float *ref, float *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  atomicAdd(reinterpret_cast<float2 *>(ref),
            static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}

TL_DEVICE float2
AtomicAddx2Ret(float *ref, float *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  return atomicAdd(reinterpret_cast<float2 *>(ref),
                   static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}

TL_DEVICE void AtomicAddx4(float *ref, float *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  atomicAdd(reinterpret_cast<float4 *>(ref),
            static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}

TL_DEVICE float4
AtomicAddx4Ret(float *ref, float *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  return atomicAdd(reinterpret_cast<float4 *>(ref),
                   static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}
#endif

template <typename T>
TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
  cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
  return aref.load(cuda::memory_order(memory_order));
}

template <typename T1, typename T2>
TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
  aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
}
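// Illustrative usage sketch for the vectorized adds and the load/store
// helpers. The kernel and the `acc`, `partial`, and `flag` buffers are
// hypothetical; it assumes the half_t pairs are contiguous and 4-byte aligned
// so they can be reinterpreted as half2.
//
//   __global__ void accumulate_kernel(half_t *acc, half_t *partial, int *flag) {
//     // Add a half2 pair with a single hardware atomic.
//     AtomicAddx2(acc + 2 * threadIdx.x, partial + 2 * threadIdx.x);
//
//     // Publish completion and read it back with explicit ordering.
//     AtomicStore(flag[0], 1, int(cuda::memory_order_release));
//     int seen = AtomicLoad(flag[0], int(cuda::memory_order_acquire));
//   }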