Commit 9ad9d9cd authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Update AtomicAdd functions for BFLOAT16 in common.h (#297)

- Added conditional compilation for BFLOAT16 atomic operations to ensure compatibility with CUDA architectures greater than 7.5.
- Improved code clarity by organizing the AtomicAdd functions and adding relevant comments for better understanding.
parent 5c8de061
......@@ -115,6 +115,8 @@ template <> TL_DEVICE void AtomicAdd(half_t *address, float val) {
atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
}
// AtomicAdd Functions for BFLOAT16
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
// AtomicAdd Functions for BFLOAT16
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t *val) {
atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
......@@ -126,13 +128,15 @@ template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) {
atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), __float2bfloat16(val));
}
#endif
// AtomicAdd for a packed pair of FP16 values.
// Reinterprets both the destination and the source half_t pair as a
// vectorized half2 and issues a single hardware atomic add, which is
// cheaper than two scalar half atomics.
// NOTE(review): assumes `address` and `val` each point to two contiguous,
// half2-aligned half_t elements — confirm against callers.
TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
  // *reinterpret_cast<half2 *>(val) already yields a half2; the original
  // static_cast<half2>(...) around it was a redundant identity cast.
  atomicAdd(reinterpret_cast<half2 *>(address),
            *reinterpret_cast<half2 *>(val));
}
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
// AtomicAdd Functions for BFLOAT16
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t val) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment