"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "95e3b5a7160da6679e9507602f801866c3672e6b"
Commit 9ad9d9cd authored by Lei Wang, committed by LeiWang1999
Browse files

[Enhancement] Update AtomicAdd functions for BFLOAT16 in common.h (#297)

- Added conditional compilation around the BFLOAT16 atomic operations so that they are only compiled when `__CUDA_ARCH_LIST__` is greater than 750, i.e. for CUDA architectures newer than 7.5.
- Improved code clarity by organizing the AtomicAdd functions and adding relevant comments for better understanding.
parent 5c8de061
...@@ -115,6 +115,8 @@ template <> TL_DEVICE void AtomicAdd(half_t *address, float val) { ...@@ -115,6 +115,8 @@ template <> TL_DEVICE void AtomicAdd(half_t *address, float val) {
atomicAdd(reinterpret_cast<half *>(address), __float2half(val)); atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
} }
// AtomicAdd Functions for BFLOAT16
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
// AtomicAdd Functions for BFLOAT16 // AtomicAdd Functions for BFLOAT16
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t *val) { template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t *val) {
atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
...@@ -126,13 +128,15 @@ template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) { ...@@ -126,13 +128,15 @@ template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) {
atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), __float2bfloat16(val)); atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), __float2bfloat16(val));
} }
#endif
// AtomicAdd Functions for FP16x2 // AtomicAdd Functions for FP16x2
TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) { TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
atomicAdd(reinterpret_cast<half2 *>(address), atomicAdd(reinterpret_cast<half2 *>(address),
static_cast<half2>(*reinterpret_cast<half2 *>(val))); static_cast<half2>(*reinterpret_cast<half2 *>(val)));
} }
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750)) #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
// AtomicAdd Functions for BFLOAT16 // AtomicAdd Functions for BFLOAT16
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t val) { template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t val) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment