Commit 46798f25 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Add atomicAdd for FLOAT16x2 and FLOAT16x4 (#522)

* [Enhancement] Add atomic addition functions for FLOAT16x2 and FLOAT16x4 in CUDA

* Introduced `AtomicAddx2` and `AtomicAddx4` functions for performing atomic addition operations on double-width float types in CUDA.
* Updated `customize.py` to include the new `atomic_addx4` function for external calls.
* Modified `__init__.py` to export the new atomic addition function, ensuring accessibility in the module.

* lint fix
parent 6addc509
......@@ -179,6 +179,19 @@ TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
}
#endif
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
// AtomicAdd Functions for FLOAT16x2
TL_DEVICE void AtomicAddx2(float *address, float *val) {
atomicAdd(reinterpret_cast<float2 *>(address),
static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}
// AtomicAdd Functions for FLOAT16x4
TL_DEVICE void AtomicAddx4(float *address, float *val) {
atomicAdd(reinterpret_cast<float4 *>(address),
static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}
#endif
// DP4A
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
......
......@@ -55,6 +55,7 @@ from .print import print # noqa: F401
from .customize import (
atomic_add, # noqa: F401
atomic_addx2, # noqa: F401
atomic_addx4, # noqa: F401
dp4a, # noqa: F401
clamp, # noqa: F401
reshape, # noqa: F401
......
......@@ -31,6 +31,19 @@ def atomic_addx2(dst: Buffer, value: PrimExpr) -> PrimExpr:
return T.call_extern("handle", "AtomicAddx2", T.address_of(dst), T.address_of(value))
def atomic_addx4(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation with double-width operands.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added (double-width)
Returns:
PrimExpr: Handle to the double-width atomic addition operation
"""
return T.call_extern("handle", "AtomicAddx4", T.address_of(dst), T.address_of(value))
def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
"""Perform a 4-element dot product with accumulation (DP4A).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment