Unverified Commit aa0b1090 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Language] Support atomic add with ret (#870)

* Add atomic operations for CUDA templates in new atomic.h file

- Introduced atomic functions including AtomicMax, AtomicMin, AtomicAdd, and their return variants for various data types.
- Implemented support for half, bfloat16, and float types with appropriate memory ordering.
- Moved atomic-related utilities from common.h to the new atomic.h file for better organization.
- Added Python bindings for atomic operations in tilelang, including atomic_max, atomic_min, atomic_add, and their vectorized counterparts.
- Updated customize.py to utilize the new atomic functions, enhancing modularity and maintainability.

* Refactor atomic operations in CUDA templates for improved readability

- Reformatted atomic operation implementations in atomic.h for better code clarity.
- Adjusted function signatures in tilelang's atomic.py to enhance readability by aligning parameters.
- Cleaned up unnecessary whitespace and comments in customize.py to streamline the codebase.

* Add thread storage synchronization configuration option

- Introduced a new configuration option `tl.disable_thread_storage_sync` to control the automatic insertion of thread synchronization barriers in shared memory access.
- Updated the `ThreadSync` pass to check this configuration and bypass synchronization if disabled.
- Enhanced documentation in `builtin.h` and `pass_config.py` to clarify the purpose and usage of the new option.

* Refactor thread storage sync configuration retrieval

- Simplified the retrieval of the thread storage sync configuration in the `ThreadSync` pass by removing unnecessary intermediate variables.
- Ensured that the inclusion of `builtin.h` is consistent by moving it to the appropriate location in the file.

* test fix

* Update atomic operations and tests for improved functionality

- Updated atomic operations in CUDA templates to remove unnecessary address_of calls, enhancing performance and readability.
- Refactored atomic operation signatures in tilelang's atomic.py to accept references instead of pointers.
- Added new atomic operations and corresponding test cases for atomic add, max, min, and load/store functionalities in the testing suite.
- Updated the TVM subproject to the latest commit for better compatibility.

* Update attention sink examples to use 32 heads

- Modified the `heads` parameter in both `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py` and `example_mha_sink_fwd_bhsd_wgmma_pipelined.py` from 1 to 32 to enhance performance in attention mechanisms.
- Ensured consistency across example scripts for improved usability and testing.

* Refactor atomic add handling in vectorization

- Simplified the extraction of buffer loads for atomic add operations by removing unnecessary address_of calls, improving code clarity and performance.
- Updated the data type retrieval for vectorization size calculation to directly access the buffer load node, enhancing efficiency.

* Add loop break functionality and enhance thread synchronization

- Introduced a new `loop_break` function in `customize.py` to allow breaking out of loops, returning a call to the `tl.loop_break` intrinsic.
- Updated the `sync_threads` function in `builtin.py` to accept optional parameters for `barrier_id` and `arrive_count`, improving its flexibility for thread synchronization.
- Added necessary imports in `__init__.py` to include the new `loop_break` function for broader accessibility.

* test fix
parent 1dfac2e8
...@@ -366,9 +366,9 @@ def gen_inputs(B, H, Sq, Skv, D, ...@@ -366,9 +366,9 @@ def gen_inputs(B, H, Sq, Skv, D,
def main( def main(
batch: int = 1, batch: int = 1,
heads: int = 64, heads: int = 32,
seq_q: int = 4096, seq_q: int = 256,
seq_kv: int = 4096, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
groups: int = 8, groups: int = 8,
window_size: int | None = None, window_size: int | None = None,
......
...@@ -229,10 +229,10 @@ def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tens ...@@ -229,10 +229,10 @@ def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tens
return query, key, value, sinks return query, key, value, sinks
def main(batch: int = 8, def main(batch: int = 1,
heads: int = 32, heads: int = 1,
seq_q: int = 4096, seq_q: int = 256,
seq_kv: int = 4096, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
window_size: int | None = None, window_size: int | None = None,
tune: bool = False): tune: bool = False):
......
...@@ -354,10 +354,10 @@ def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tens ...@@ -354,10 +354,10 @@ def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tens
return query, key, value, sinks return query, key, value, sinks
def main(batch: int = 8, def main(batch: int = 1,
heads: int = 32, heads: int = 32,
seq_q: int = 4096, seq_q: int = 256,
seq_kv: int = 4096, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
window_size: int | None = None, window_size: int | None = None,
tune: bool = False): tune: bool = False):
......
...@@ -293,10 +293,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -293,10 +293,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
if (dst_predicate.defined()) if (dst_predicate.defined())
dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype)); dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));
Call address_of_value = new_args.push_back(dst_value);
tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});
new_args.push_back(address_of_value);
new_args.push_back(src_value); new_args.push_back(src_value);
Call atomicadd_call = Call atomicadd_call =
......
...@@ -20,6 +20,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool); ...@@ -20,6 +20,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer); TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer); TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
......
...@@ -55,6 +55,20 @@ static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; ...@@ -55,6 +55,20 @@ static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
static constexpr const char *kDisableDynamicTailSplit = static constexpr const char *kDisableDynamicTailSplit =
"tl.disable_dynamic_tail_split"; "tl.disable_dynamic_tail_split";
/*!
* \brief Whether to disable thread storage synchronization
*
* When enabled, disables the automatic insertion of thread synchronization
* barriers (e.g., __syncthreads()) for shared memory access coordination.
* This can be useful for performance optimization in cases where manual
* synchronization is preferred or when synchronization is not needed.
*
* kDisableThreadStorageSync = "tl.disable_thread_storage_sync"
*
*/
static constexpr const char *kDisableThreadStorageSync =
"tl.disable_thread_storage_sync";
/*! /*!
* \brief The size of the vectorized dimension in buffer, designed by user * \brief The size of the vectorized dimension in buffer, designed by user
* *
......
#pragma once
#ifndef __CUDACC_RTC__
#include <cuda_runtime.h>
#endif
#include <cuda/atomic>
#include <cutlass/numeric_types.h>
using cutlass::bfloat16_t;
using cutlass::half_t;
#define TL_DEVICE __forceinline__ __device__
template <typename T> struct normalize_atomic_type {
using type = T;
};
template <> struct normalize_atomic_type<half_t> {
using type = half;
};
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
template <> struct normalize_atomic_type<bfloat16_t> {
using type = __nv_bfloat16;
};
#endif
template <typename T1, typename T2> TL_DEVICE T1 cuda_cast(T2 val) {
return T1(val);
}
template <> TL_DEVICE half cuda_cast<half, float>(float val) {
return __float2half(val);
}
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
return __float2bfloat16(val);
}
#endif
template <typename T1, typename T2>
TL_DEVICE void AtomicMax(T1 &ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
template <typename T1, typename T2>
TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
return static_cast<T1>(
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
}
}
template <typename T1, typename T2>
TL_DEVICE void AtomicMin(T1 &ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
template <typename T1, typename T2>
TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
return static_cast<T1>(
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
}
}
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
template <typename T1, typename T2>
TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
return static_cast<T1>(
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
}
}
TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val) {
atomicAdd(reinterpret_cast<half2 *>(ref),
static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}
TL_DEVICE half2 AtomicAddx2Ret(half_t *ref, half_t *val) {
return atomicAdd(reinterpret_cast<half2 *>(ref),
static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val) {
atomicAdd(
reinterpret_cast<__nv_bfloat162 *>(ref),
static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}
TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val) {
return atomicAdd(
reinterpret_cast<__nv_bfloat162 *>(ref),
static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}
#endif
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
TL_DEVICE void AtomicAddx2(float *ref, float *val) {
atomicAdd(reinterpret_cast<float2 *>(ref),
static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}
TL_DEVICE float2 AtomicAddx2Ret(float *ref, float *val) {
return atomicAdd(reinterpret_cast<float2 *>(ref),
static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}
TL_DEVICE void AtomicAddx4(float *ref, float *val) {
atomicAdd(reinterpret_cast<float4 *>(ref),
static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}
TL_DEVICE float4 AtomicAddx4Ret(float *ref, float *val) {
return atomicAdd(reinterpret_cast<float4 *>(ref),
static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}
#endif
template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
return aref.load(cuda::memory_order(memory_order));
}
template <typename T1, typename T2>
TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
using NT1 = typename normalize_atomic_type<T1>::type;
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
}
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#endif #endif
#include <cuda/atomic> #include "atomic.h"
#include <cutlass/fast_math.h> #include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h> #include <cutlass/numeric_types.h>
#include <math_constants.h> #include <math_constants.h>
...@@ -138,141 +138,6 @@ TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) { ...@@ -138,141 +138,6 @@ TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
return smem_int; return smem_int;
} }
template <typename T> struct normalize_atomic_type {
using type = T;
};
template <> /**
* Map the public half_t alias to the native `half` type for atomic
* operations.
*
* Used by the atomic utilities to normalize externally exposed
* typedefs (e.g., Cutlass half_t) to the compiler's native `half`
* representation so correct atomic intrinsics or `cuda::atomic_ref`
* specializations can be selected.
*/
struct normalize_atomic_type<half_t> {
using type = half;
};
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
template <> struct normalize_atomic_type<bfloat16_t> {
using type = __nv_bfloat16;
};
#endif
template <typename T1, typename T2> TL_DEVICE T1 cuda_cast(T2 val) {
return T1(val);
}
template <> TL_DEVICE half cuda_cast<half, float>(float val) {
return __float2half(val);
}
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
return __float2bfloat16(val);
}
#endif
template <typename T1, typename T2>
TL_DEVICE void AtomicMax(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
template <typename T1, typename T2>
TL_DEVICE void AtomicMin(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
// AtomicAdd Functions for FP16x2
TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
atomicAdd(reinterpret_cast<half2 *>(address),
static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
// AtomicAdd Functions for BFLOAT16x2
TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
atomicAdd(
reinterpret_cast<__nv_bfloat162 *>(address),
static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}
#endif
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
// AtomicAdd Functions for FLOAT16x2
TL_DEVICE void AtomicAddx2(float *address, float *val) {
atomicAdd(reinterpret_cast<float2 *>(address),
static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}
// AtomicAdd Functions for FLOAT16x4
TL_DEVICE void AtomicAddx4(float *address, float *val) {
atomicAdd(reinterpret_cast<float4 *>(address),
static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}
#endif
template <typename T> TL_DEVICE T AtomicLoad(T *address, int memory_order) {
cuda::atomic_ref<T, cuda::thread_scope_device> aref(*address);
return aref.load(cuda::memory_order(memory_order));
}
template <typename T1, typename T2>
TL_DEVICE /**
* Atomically stores a value into the given address using the
* specified memory ordering.
*
* The value is converted to the normalized atomic storage type for T1
* before being stored (for example, vectorized or reduced-width types
* such as FP16/BF16 are mapped to their underlying hardware
* representation). `memory_order` must be an `int` representation of
* a `cuda::memory_order` value (e.g.,
* `int(cuda::memory_order_relaxed)`).
*
* @param address Pointer to the destination atomic object.
* @param value Value to store; will be cast to the atomic storage
* type.
* @param memory_order Memory ordering for the atomic store (as an
* `int`-cast `cuda::memory_order`).
*/
void
AtomicStore(T1 *address, T2 value, int memory_order) {
using NT1 = typename normalize_atomic_type<T1>::type;
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
}
// DP4A // DP4A
template <typename InDatatype, typename OutDatatype> template <typename InDatatype, typename OutDatatype>
TL_DEVICE /** TL_DEVICE /**
......
...@@ -54,13 +54,8 @@ private: ...@@ -54,13 +54,8 @@ private:
if (node->op == builtin::call_extern() && node->args.size() >= 2) { if (node->op == builtin::call_extern() && node->args.size() >= 2) {
if (const auto *func_name = node->args[0].as<StringImmNode>()) { if (const auto *func_name = node->args[0].as<StringImmNode>()) {
if (func_name->value == "AtomicAdd") { if (func_name->value == "AtomicAdd") {
const CallNode *addr_call = node->args[1].as<CallNode>();
if (addr_call && addr_call->op == builtin::address_of() &&
addr_call->args.size() == 1) {
const BufferLoadNode *buffer_load_dst = const BufferLoadNode *buffer_load_dst =
addr_call->args[0].as<BufferLoadNode>(); node->args[1].as<BufferLoadNode>();
const BufferLoadNode *buffer_load_src = const BufferLoadNode *buffer_load_src =
node->args[2].as<BufferLoadNode>(); node->args[2].as<BufferLoadNode>();
if (buffer_load_src && buffer_load_src->buffer.defined() && if (buffer_load_src && buffer_load_src->buffer.defined() &&
...@@ -76,7 +71,6 @@ private: ...@@ -76,7 +71,6 @@ private:
} }
} }
} }
}
return arith::IRVisitorWithAnalyzer::VisitExpr_(node); return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
} }
...@@ -219,13 +213,8 @@ private: ...@@ -219,13 +213,8 @@ private:
// bx * stride_x + (i % (stride_x / (tx_extent * // bx * stride_x + (i % (stride_x / (tx_extent *
// vector_size_)) * (tx_extent * vector_size_) + (tx_var_ % // vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
// (stride / vector_size_)) * vector_size_] // (stride / vector_size_)) * vector_size_]
const CallNode *addr_call = node->args[1].as<CallNode>();
if (!addr_call || addr_call->op != builtin::address_of() ||
addr_call->args.size() != 1) {
return StmtExprMutator::VisitExpr_(node);
}
const BufferLoadNode *old_dst_node = const BufferLoadNode *old_dst_node =
addr_call->args[0].as<BufferLoadNode>(); node->args[1].as<BufferLoadNode>();
const BufferLoadNode *old_value_node = const BufferLoadNode *old_value_node =
node->args[2].as<BufferLoadNode>(); node->args[2].as<BufferLoadNode>();
if (!old_dst_node || !old_value_node) { if (!old_dst_node || !old_value_node) {
...@@ -339,8 +328,7 @@ For VectorizeAtomicAdd(const For &for_node, const Var &thread_var, ...@@ -339,8 +328,7 @@ For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
if (call->op == builtin::call_extern() && call->args.size() >= 2) { if (call->op == builtin::call_extern() && call->args.size() >= 2) {
const auto *func_name = call->args[0].as<StringImmNode>(); const auto *func_name = call->args[0].as<StringImmNode>();
if (func_name->value == "AtomicAdd") { if (func_name->value == "AtomicAdd") {
DataType dtype = DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
call->args[1].as<CallNode>()->args[0].as<BufferLoadNode>()->dtype;
vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
} }
} }
......
...@@ -235,7 +235,8 @@ private: ...@@ -235,7 +235,8 @@ private:
bool IsLocalBuffer(const Buffer &buffer) { bool IsLocalBuffer(const Buffer &buffer) {
String scope = buffer.scope(); String scope = buffer.scope();
return scope == "local" || scope == "local.fragment"; return scope == "local" || scope == "local.fragment" ||
scope == "local.var";
} }
bool isSharedBuffer(const Buffer &buffer) { bool isSharedBuffer(const Buffer &buffer) {
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "../op/builtin.h"
#include "./common/thread_sync_types.h" #include "./common/thread_sync_types.h"
#include "./storage_access.h" #include "./storage_access.h"
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
...@@ -769,6 +770,12 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) { ...@@ -769,6 +770,12 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, const IRModule &m, auto pass_func = [storage_scope](PrimFunc f, const IRModule &m,
const PassContext &ctx) { const PassContext &ctx) {
auto *n = f.CopyOnWrite(); auto *n = f.CopyOnWrite();
// Check if thread storage sync is disabled
bool disable_syncthreads =
ctx->GetConfig(kDisableThreadStorageSync, Bool(false)).value()->value;
if (disable_syncthreads) {
return f;
}
return tl::TileLangThreadSync(std::move(f), storage_scope); return tl::TileLangThreadSync(std::move(f), storage_scope);
; ;
}; };
......
...@@ -2,6 +2,7 @@ import tilelang.testing ...@@ -2,6 +2,7 @@ import tilelang.testing
import tilelang.language as T import tilelang.language as T
@tilelang.jit
def atomic_add_program(K, M, N, block_M, block_N, dtype="float"): def atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func @T.prim_func
...@@ -19,9 +20,7 @@ def atomic_add_program(K, M, N, block_M, block_N, dtype="float"): ...@@ -19,9 +20,7 @@ def atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"): def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
program = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype)
kernel = tilelang.compile(program)
# print(kernel.get_kernel_source())
import torch import torch
def ref_program(A, B): def ref_program(A, B):
...@@ -35,12 +34,348 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"): ...@@ -35,12 +34,348 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
ref_B = B.clone() ref_B = B.clone()
ref_program(A, ref_B) ref_program(A, ref_B)
kernel(A, B) kernel(A, B)
torch.testing.assert_close(B, ref_B) torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
@tilelang.jit
def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
T.atomic_add(B[bx * block_M, by * block_N], A_shared)
return atomic_add
def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
kernel = tile_atomic_add_program(K, M, N, block_M, block_N, dtype=dtype)
print(kernel.get_kernel_source())
import torch
def ref_program(A, B):
for k in range(K):
for i in range(M):
for j in range(N):
B[i, j] += A[k, i, j]
A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
ref_B = B.clone()
ref_program(A, ref_B)
kernel(A, B)
print(B)
print(ref_B)
torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
@tilelang.jit
def atomic_max_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
for i, j in T.Parallel(block_M, block_N):
T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
return atomic_max
def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
kernel = atomic_max_program(K, M, N, block_M, block_N, dtype=dtype)
import torch
def ref_program(A, B):
for k in range(K):
for i in range(M):
for j in range(N):
B[i, j] = max(B[i, j], A[k, i, j])
A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
ref_B = B.clone()
ref_program(A, ref_B)
kernel(A, B)
torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
@tilelang.jit
def atomic_min_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
for i, j in T.Parallel(block_M, block_N):
T.atomic_min(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
return atomic_min
def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
kernel = atomic_min_program(K, M, N, block_M, block_N, dtype=dtype)
import torch
def ref_program(A, B):
for k in range(K):
for i in range(M):
for j in range(N):
B[i, j] = min(B[i, j], A[k, i, j])
A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda()
ref_B = B.clone()
ref_program(A, ref_B)
kernel(A, B)
torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
@tilelang.jit
def atomic_load_store_program(M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
idx_i = bx * block_M + i
idx_j = by * block_N + j
if idx_i < M and idx_j < N:
val = T.atomic_load(A[idx_i, idx_j])
T.atomic_store(B[idx_i, idx_j], val)
return atomic_load_store
def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
kernel = atomic_load_store_program(M, N, block_M, block_N, dtype=dtype)
import torch
A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
kernel(A, B)
torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3)
@tilelang.jit
def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(
B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed")
return atomic_with_memory_order
def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
kernel = atomic_memory_order_program(K, M, N, block_M, block_N, dtype=dtype)
import torch
def ref_program(A, B):
for k in range(K):
for i in range(M):
for j in range(N):
B[i, j] += A[k, i, j]
A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
ref_B = B.clone()
ref_program(A, ref_B)
kernel(A, B)
torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
@tilelang.jit
def atomic_addx2_program(M, N, block_M, block_N):
@T.prim_func
def atomic_addx2(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
for i, j in T.Parallel(block_M, block_N // 2):
idx_i = bx * block_M + i
idx_j = by * block_N + j * 2
T.atomic_addx2(B[idx_i, idx_j], A[idx_i, idx_j])
return atomic_addx2
def run_atomic_addx2(M, N, block_M, block_N):
kernel = atomic_addx2_program(M, N, block_M, block_N)
import torch
A = torch.randn(M, N, dtype=torch.float16).cuda()
B = torch.zeros(M, N, dtype=torch.float16).cuda()
ref_B = B.clone()
for i in range(M):
for j in range(0, N - 1, 2):
ref_B[i, j] += A[i, j]
ref_B[i, j + 1] += A[i, j + 1]
kernel(A, B)
torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
@tilelang.jit
def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_different_orders(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor(
(M, N), dtype), D: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
idx_i = bx * block_M + i
idx_j = by * block_N + j
if idx_i < M and idx_j < N:
val = A[idx_i, idx_j]
T.atomic_add(B[idx_i, idx_j], val, memory_order="relaxed")
T.atomic_max(C[idx_i, idx_j], val, memory_order="acquire")
T.atomic_min(D[idx_i, idx_j], val, memory_order="release")
return atomic_different_orders
def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
kernel = atomic_different_memory_orders_program(M, N, block_M, block_N, dtype=dtype)
import torch
A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
D = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda()
kernel(A, B, C, D)
torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(C, torch.maximum(torch.zeros_like(A), A))
torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float('inf')), A))
def test_atomic_add(): def test_atomic_add():
run_atomic_add(8, 128, 128, 32, 32) run_atomic_add(8, 128, 128, 32, 32)
def test_atomic_max():
run_atomic_max(4, 64, 64, 16, 16)
def test_atomic_min():
run_atomic_min(4, 64, 64, 16, 16)
def test_atomic_load_store():
run_atomic_load_store(64, 64, 16, 16)
def test_atomic_memory_order():
run_atomic_memory_order(4, 64, 64, 16, 16)
def test_atomic_addx2():
run_atomic_addx2(32, 64, 8, 16)
@tilelang.jit
def atomic_addx4_program(M, N, block_M, block_N):
@T.prim_func
def atomic_addx4(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
for i, j in T.Parallel(block_M, block_N // 4):
idx_i = bx * block_M + i
idx_j = by * block_N + j * 4
T.atomic_addx4(B[idx_i, idx_j], A[idx_i, idx_j])
return atomic_addx4
def run_atomic_addx4(M, N, block_M, block_N):
kernel = atomic_addx4_program(M, N, block_M, block_N)
import torch
A = torch.randn(M, N, dtype=torch.float32).cuda()
B = torch.zeros(M, N, dtype=torch.float32).cuda()
ref_B = B.clone()
for i in range(M):
for j in range(0, N - 3, 4):
ref_B[i, j] += A[i, j]
ref_B[i, j + 1] += A[i, j + 1]
ref_B[i, j + 2] += A[i, j + 2]
ref_B[i, j + 3] += A[i, j + 3]
kernel(A, B)
torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
@tilelang.jit
def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
old_vals: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
idx_i = bx * block_M + i
idx_j = by * block_N + j
if idx_i < M and idx_j < N:
old_vals[idx_i, idx_j] = T.atomic_add(
B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True)
return atomic_with_return_prev
def run_atomic_return_prev(M, N, block_M, block_N, dtype="float32"):
kernel = atomic_return_prev_program(M, N, block_M, block_N, dtype=dtype)
import torch
A = torch.ones(M, N, dtype=getattr(torch, dtype)).cuda() * 5.0
B = torch.ones(M, N, dtype=getattr(torch, dtype)).cuda() * 2.0
old_vals = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
initial_B = B.clone()
kernel(A, B, old_vals)
torch.testing.assert_close(old_vals, initial_B, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(B, initial_B + A, atol=1e-3, rtol=1e-3)
def test_atomic_different_memory_orders():
run_atomic_different_memory_orders(32, 32, 8, 8)
def test_atomic_addx4():
run_atomic_addx4(16, 64, 4, 4)
def test_atomic_return_prev():
run_atomic_return_prev(32, 32, 8, 8)
# TODO(lei): test failed and this is experimental
# CC @dyq
# def test_tile_atomic_add():
# run_tile_atomic_add(8, 128, 128, 32, 32)
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -70,6 +70,7 @@ from .customize import ( ...@@ -70,6 +70,7 @@ from .customize import (
view, # noqa: F401 view, # noqa: F401
atomic_load, # noqa: F401 atomic_load, # noqa: F401
atomic_store, # noqa: F401 atomic_store, # noqa: F401
loop_break, # noqa: F401
) )
from .logical import any_of, all_of # noqa: F401 from .logical import any_of, all_of # noqa: F401
from .builtin import * # noqa: F401 from .builtin import * # noqa: F401
......
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
"""Atomic operations for tilelang."""
import tilelang.language as T
from tvm import ir
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
from typing import Optional
_MEMORY_ORDER_ID_MAP = {
"relaxed": 0,
"consume": 1,
"acquire": 2,
"release": 3,
"acq_rel": 4,
"seq_cst": 5,
}
def atomic_max(dst: Buffer,
value: PrimExpr,
memory_order: Optional[str] = None,
return_prev: bool = False) -> PrimExpr:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
If memory_order is None the runtime extern "AtomicMax" is called without an explicit memory-order id; otherwise the provided memory_order string is mapped to a numeric id using the module's memory-order map and passed to the extern.
Parameters:
dst (Buffer): Destination buffer/address to apply the atomic max.
value (PrimExpr): Value to compare/store atomically.
memory_order (Optional[str]): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst").
If provided, it is translated to the corresponding numeric memory-order id before the call.
return_prev (bool): If True, return the previous value; if False, return handle (default False).
Returns:
PrimExpr: A handle/expression representing the issued atomic maximum operation, or the previous value if return_prev is True.
Examples:
>>> # Basic atomic max operation
>>> counter = T.Tensor([1], "float32", name="counter")
>>> atomic_max(counter, 42.0)
>>> # With memory ordering
>>> atomic_max(counter, 100.0, memory_order="acquire")
>>> # Get the previous value
>>> prev_value = atomic_max(counter, 50.0, return_prev=True)
>>> # prev_value now contains the value that was in counter before the max operation
>>> # Use in parallel reduction to find global maximum
>>> @T.prim_func
>>> def find_max(data: T.Buffer, result: T.Buffer):
>>> for i in T.thread_binding(128, "threadIdx.x"):
>>> atomic_max(result, data[i])
"""
func_name = "AtomicMaxRet" if return_prev else "AtomicMax"
return_type = dst.dtype if return_prev else "handle"
if memory_order is None:
return T.call_extern(return_type, func_name, dst, value)
else:
return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_min(dst: Buffer,
value: PrimExpr,
memory_order: Optional[str] = None,
return_prev: bool = False) -> PrimExpr:
"""
Atomically update the value at dst to the minimum of its current value and value.
If memory_order is provided, it selects the memory-order semantic used by the underlying extern call;
allowed names are "relaxed", "consume", "acquire", "release", "acq_rel", and "seq_cst" (mapped internally
to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument.
Parameters:
dst (Buffer): Destination buffer/address to apply the atomic min.
value (PrimExpr): Value to compare/store atomically.
memory_order (Optional[str]): Optional memory-order name controlling the atomic operation's ordering.
return_prev (bool): If True, return the previous value; if False, return handle (default False).
Returns:
PrimExpr: A handle expression representing the atomic-min operation, or the previous value if return_prev is True.
Examples:
>>> # Basic atomic min operation
>>> min_val = T.Tensor([1], "int32", name="min_val")
>>> atomic_min(min_val, 10)
>>> # Find minimum across threads
>>> @T.prim_func
>>> def find_min(data: T.Buffer, result: T.Buffer):
>>> for i in T.thread_binding(256, "threadIdx.x"):
>>> atomic_min(result, data[i])
>>> # Track minimum with previous value
>>> threshold = T.Tensor([1], "float32", name="threshold")
>>> old_min = atomic_min(threshold, 3.14, return_prev=True)
>>> # old_min contains the previous minimum value
>>> # With relaxed memory ordering for performance
>>> atomic_min(min_val, 5, memory_order="relaxed")
"""
func_name = "AtomicMinRet" if return_prev else "AtomicMin"
return_type = dst.dtype if return_prev else "handle"
if memory_order is None:
return T.call_extern(return_type, func_name, dst, value)
else:
return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_add(dst: Buffer,
value: PrimExpr,
memory_order: Optional[str] = None,
return_prev: bool = False) -> PrimExpr:
"""
Atomically add `value` into `dst`, returning a handle to the operation.
Supports scalar/addressed extern atomic add when neither argument exposes extents, or tile-region-based atomic add for Buffer/BufferRegion/BufferLoad inputs. If both arguments are plain Buffers their shapes must be structurally equal. If at least one side exposes extents, extents are aligned (missing dimensions are treated as size 1); an assertion is raised if extents cannot be deduced. The optional `memory_order` (one of "relaxed","consume","acquire","release","acq_rel","seq_cst") is used only for the direct extern `AtomicAdd` path when no extents are available — otherwise the tile-region path ignores `memory_order`.
Parameters:
dst (Buffer): Destination buffer/address to apply the atomic add.
value (PrimExpr): Value to add atomically.
memory_order (Optional[str]): Optional memory-order name controlling the atomic operation's ordering.
return_prev (bool): If True, return the previous value; if False, return handle (default False).
Returns:
PrimExpr: A handle representing the atomic addition operation, or the previous value if return_prev is True.
Examples:
>>> # Basic atomic addition
>>> counter = T.Tensor([1], "int32", name="counter")
>>> atomic_add(counter, 1) # Increment counter by 1
>>> # Parallel sum reduction
>>> @T.prim_func
>>> def parallel_sum(data: T.Buffer, result: T.Buffer):
>>> for i in T.thread_binding(1024, "threadIdx.x"):
>>> atomic_add(result, data[i])
>>> # Get previous value for debugging
>>> old_value = atomic_add(counter, 5, return_prev=True)
>>> # old_value contains the value before adding 5
>>> # Tensor-to-tensor atomic add (tile-region based)
>>> src_tensor = T.Tensor([128, 64], "float32", name="src")
>>> dst_tensor = T.Tensor([128, 64], "float32", name="dst")
>>> atomic_add(dst_tensor, src_tensor) # Add entire tensors atomically
>>> # With memory ordering for scalar operations
>>> atomic_add(counter, 10, memory_order="acquire")
>>> # Accumulate gradients in training
>>> gradients = T.Tensor([1000], "float32", name="gradients")
>>> global_grad = T.Tensor([1000], "float32", name="global_grad")
>>> atomic_add(global_grad, gradients)
"""
def get_extent(data):
"""
Return the inferred extent (shape) of a buffer-like object.
If `data` is a Var bound to a let value, the let value is resolved before inspection.
Parameters:
data: A Var, Buffer, or BufferRegion to inspect.
Returns:
The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents), or None if the extent cannot be determined.
"""
if isinstance(data, Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, Buffer):
return data.shape
elif isinstance(data, BufferRegion):
return [x.extent for x in data.region]
else:
return None
src_extent = get_extent(value)
dst_extent = get_extent(dst)
if dst_extent is None and src_extent is None:
func_name = "AtomicAddRet" if return_prev else "AtomicAdd"
return_type = dst.dtype if return_prev else "handle"
if memory_order is None:
return T.call_extern(return_type, func_name, dst, value)
else:
return T.call_extern(return_type, func_name, dst, value,
_MEMORY_ORDER_ID_MAP[memory_order])
if isinstance(dst, Buffer) and isinstance(value, Buffer):
ir.assert_structural_equal(dst.shape, value.shape)
assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
extent = max(src_extent, dst_extent)
def _to_region(data, access_type):
from .customize import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
if isinstance(data, Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, Buffer):
return buffer_to_tile_region(data, access_type)
elif isinstance(data, BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
else:
return buffer_load_to_tile_region(data, access_type, extent)
value = _to_region(value, "r")
dst = _to_region(dst, "w")
# Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime
if return_prev:
raise NotImplementedError(
"return_prev is not supported for tile-region-based atomic operations")
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst)
def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
"""Perform an atomic addition operation with double-width operands.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added (double-width)
return_prev (bool): If True, return the previous value; if False, return handle (default False)
Returns:
PrimExpr: Handle to the double-width atomic addition operation, or the previous value if return_prev is True
Examples:
>>> # Atomic addition with FP16 pairs
>>> half_dst = T.Tensor([2], "float16", name="half_dst")
>>> half_val = T.Tensor([2], "float16", name="half_val")
>>> atomic_addx2(half_dst, half_val)
>>> # BF16 vectorized atomic add (requires CUDA Arch > 750)
>>> bf16_dst = T.Tensor([2], "bfloat16", name="bf16_dst")
>>> bf16_val = T.Tensor([2], "bfloat16", name="bf16_val")
>>> atomic_addx2(bf16_dst, bf16_val)
>>> # Get previous paired values
>>> prev_values = atomic_addx2(half_dst, half_val, return_prev=True)
>>> # prev_values is a half2 containing the two previous FP16 values
>>> # Efficient gradient accumulation for mixed precision training
>>> @T.prim_func
>>> def accumulate_fp16_gradients(grads: T.Buffer, global_grads: T.Buffer):
>>> for i in T.thread_binding(128, "threadIdx.x"):
>>> for j in range(0, grads.shape[1], 2): # Process in pairs
>>> atomic_addx2(global_grads[i, j:j+2], grads[i, j:j+2])
"""
func_name = "AtomicAddx2Ret" if return_prev else "AtomicAddx2"
return_type = dst.dtype if return_prev else "handle"
return T.call_extern(return_type, func_name, T.address_of(dst), T.address_of(value))
def atomic_addx4(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
"""Perform an atomic addition operation with quad-width operands.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added (quad-width)
return_prev (bool): If True, return the previous value; if False, return handle (default False)
Returns:
PrimExpr: Handle to the quad-width atomic addition operation, or the previous value if return_prev is True
Examples:
>>> # Atomic addition with float4 (requires CUDA Arch >= 900)
>>> float4_dst = T.Tensor([4], "float32", name="float4_dst")
>>> float4_val = T.Tensor([4], "float32", name="float4_val")
>>> atomic_addx4(float4_dst, float4_val)
>>> # Get previous float4 values
>>> prev_float4 = atomic_addx4(float4_dst, float4_val, return_prev=True)
>>> # prev_float4 is a float4 containing the four previous float32 values
>>> # High-throughput gradient accumulation for large models
>>> @T.prim_func
>>> def accumulate_float4_gradients(grads: T.Buffer, global_grads: T.Buffer):
>>> for i in T.thread_binding(256, "threadIdx.x"):
>>> for j in range(0, grads.shape[1], 4): # Process 4 floats at once
>>> atomic_addx4(global_grads[i, j:j+4], grads[i, j:j+4])
>>> # Efficient RGBA pixel blending
>>> rgba_dst = T.Tensor([4], "float32", name="rgba_dst") # R, G, B, A channels
>>> rgba_add = T.Tensor([4], "float32", name="rgba_add")
>>> atomic_addx4(rgba_dst, rgba_add) # Atomic blend of all 4 channels
"""
func_name = "AtomicAddx4Ret" if return_prev else "AtomicAddx4"
return_type = "float4" if "float" in str(dst.dtype).lower() else "handle"
return T.call_extern(return_type, func_name, T.address_of(dst), T.address_of(value))
def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
"""
Load a value from the given buffer using the specified atomic memory ordering.
Performs an atomic load from `src` and returns a PrimExpr representing the loaded value.
memory_order selects the ordering and must be one of: "relaxed", "consume", "acquire",
"release", "acq_rel", or "seq_cst" (default).
Raises KeyError if an unknown memory_order is provided.
Note: atomic_load always returns the loaded value, so no return_prev parameter is needed.
Examples:
>>> # Basic atomic load
>>> shared_var = T.Tensor([1], "int32", name="shared_var")
>>> value = atomic_load(shared_var)
>>> # Load with specific memory ordering
>>> value = atomic_load(shared_var, memory_order="acquire")
>>> # Ensures all subsequent memory operations happen after this load
>>> # Relaxed load for performance-critical code
>>> value = atomic_load(shared_var, memory_order="relaxed")
>>> # Producer-consumer pattern
>>> @T.prim_func
>>> def consumer(flag: T.Buffer, data: T.Buffer, result: T.Buffer):
>>> # Wait until producer sets flag
>>> while atomic_load(flag, memory_order="acquire") == 0:
>>> pass # Spin wait
>>> # Now safely read data
>>> result[0] = data[0]
>>> # Load counter for statistics
>>> counter = T.Tensor([1], "int64", name="counter")
>>> current_count = atomic_load(counter, memory_order="relaxed")
"""
return T.call_extern(src.dtype, "AtomicLoad", src, _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
"""
Perform an atomic store of `src` into `dst` with the given memory ordering.
Parameters:
dst (Buffer): Destination buffer to store into.
src (PrimExpr): Value to store.
memory_order (str, optional): Memory ordering name; one of "relaxed", "consume",
"acquire", "release", "acq_rel", or "seq_cst". Defaults to "seq_cst".
The name is mapped to an internal numeric ID used by the underlying runtime.
Returns:
PrimExpr: A handle representing the issued atomic store operation.
Raises:
KeyError: If `memory_order` is not one of the supported names.
Note: atomic_store doesn't return a previous value, so no return_prev parameter is needed.
Examples:
>>> # Basic atomic store
>>> shared_var = T.Tensor([1], "int32", name="shared_var")
>>> atomic_store(shared_var, 42)
>>> # Store with release ordering to publish data
>>> data = T.Tensor([1000], "float32", name="data")
>>> ready_flag = T.Tensor([1], "int32", name="ready_flag")
>>> # ... fill data ...
>>> atomic_store(ready_flag, 1, memory_order="release")
>>> # Ensures all previous writes are visible before flag is set
>>> # Relaxed store for performance
>>> atomic_store(shared_var, 100, memory_order="relaxed")
>>> # Producer-consumer synchronization
>>> @T.prim_func
>>> def producer(data: T.Buffer, flag: T.Buffer):
>>> data[0] = 3.14159 # Write data first
>>> atomic_store(flag, 1, memory_order="release")
>>> # Consumer can now safely read data after seeing flag == 1
>>> # Update configuration atomically
>>> config = T.Tensor([1], "int32", name="config")
>>> new_config = 0x12345678
>>> atomic_store(config, new_config, memory_order="seq_cst")
>>> # Thread-safe logging counter
>>> log_counter = T.Tensor([1], "int64", name="log_counter")
>>> atomic_store(log_counter, 0) # Reset counter atomically
"""
return T.call_extern("handle", "AtomicStore", dst, src, _MEMORY_ORDER_ID_MAP[memory_order])
...@@ -330,10 +330,15 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, ...@@ -330,10 +330,15 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
def sync_threads(): def sync_threads(barrier_id: int = None, arrive_count: int = None):
"""Synchronize all threads in a block. """Synchronize all threads in a block.
""" """
return tir.op.tvm_storage_sync("shared") args = []
if barrier_id is not None:
args.append(barrier_id)
if arrive_count is not None:
args.append(arrive_count)
return tir.call_intrin("int32", "tir.tvm_storage_sync", "shared", *args)
def sync_global(): def sync_global():
......
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs.""" """The language interface for tl programs."""
import tilelang.language as T import tilelang.language as T
from tvm import ir from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, op
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, Var, op from typing import List, Union
from typing import List, Union, Optional from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
_MEMORY_ORDER_ID_MAP = {
"relaxed": 0,
"consume": 1,
"acquire": 2,
"release": 3,
"acq_rel": 4,
"seq_cst": 5,
}
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
...@@ -104,138 +93,6 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str, ...@@ -104,138 +93,6 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents) return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None) -> PrimExpr:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
If memory_order is None the runtime extern "AtomicMax" is called without an explicit memory-order id; otherwise the provided memory_order string is mapped to a numeric id using the module's memory-order map and passed to the extern.
Parameters:
dst (Buffer): Destination buffer/address to apply the atomic max.
value (PrimExpr): Value to compare/store atomically.
memory_order (Optional[str]): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst").
If provided, it is translated to the corresponding numeric memory-order id before the call.
Returns:
PrimExpr: A handle/expression representing the issued atomic maximum operation.
"""
if memory_order is None:
return T.call_extern("handle", "AtomicMax", T.address_of(dst), value)
else:
return T.call_extern("handle", "AtomicMax", T.address_of(dst), value,
_MEMORY_ORDER_ID_MAP[memory_order])
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None) -> PrimExpr:
"""
Atomically update the value at dst to the minimum of its current value and value.
If memory_order is provided, it selects the memory-order semantic used by the underlying extern call;
allowed names are "relaxed", "consume", "acquire", "release", "acq_rel", and "seq_cst" (mapped internally
to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument.
Parameters:
memory_order (Optional[str]): Optional memory-order name controlling the atomic operation's ordering.
Returns:
PrimExpr: A handle expression representing the atomic-min operation.
"""
if memory_order is None:
return T.call_extern("handle", "AtomicMin", T.address_of(dst), value)
else:
return T.call_extern("handle", "AtomicMin", T.address_of(dst), value,
_MEMORY_ORDER_ID_MAP[memory_order])
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None) -> PrimExpr:
"""
Atomically add `value` into `dst`, returning a handle to the operation.
Supports scalar/addressed extern atomic add when neither argument exposes extents, or tile-region-based atomic add for Buffer/BufferRegion/BufferLoad inputs. If both arguments are plain Buffers their shapes must be structurally equal. If at least one side exposes extents, extents are aligned (missing dimensions are treated as size 1); an assertion is raised if extents cannot be deduced. The optional `memory_order` (one of "relaxed","consume","acquire","release","acq_rel","seq_cst") is used only for the direct extern `AtomicAdd` path when no extents are available — otherwise the tile-region path ignores `memory_order`.
Returns:
PrimExpr: A handle representing the atomic addition operation.
"""
def get_extent(data):
"""
Return the inferred extent (shape) of a buffer-like object.
If `data` is a Var bound to a let value, the let value is resolved before inspection.
Parameters:
data: A Var, Buffer, or BufferRegion to inspect.
Returns:
The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents), or None if the extent cannot be determined.
"""
if isinstance(data, Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, Buffer):
return data.shape
elif isinstance(data, BufferRegion):
return [x.extent for x in data.region]
else:
return None
src_extent = get_extent(value)
dst_extent = get_extent(dst)
if dst_extent is None and src_extent is None:
if memory_order is None:
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)
else:
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value,
_MEMORY_ORDER_ID_MAP[memory_order])
if isinstance(dst, Buffer) and isinstance(value, Buffer):
ir.assert_structural_equal(dst.shape, value.shape)
assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
extent = max(src_extent, dst_extent)
def _to_region(data, access_type):
if isinstance(data, Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, Buffer):
return buffer_to_tile_region(data, access_type)
elif isinstance(data, BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
else:
return buffer_load_to_tile_region(data, access_type, extent)
value = _to_region(value, "r")
dst = _to_region(dst, "w")
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst)
def atomic_addx2(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation with double-width operands.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added (double-width)
Returns:
PrimExpr: Handle to the double-width atomic addition operation
"""
return T.call_extern("handle", "AtomicAddx2", T.address_of(dst), T.address_of(value))
def atomic_addx4(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation with quad-width operands.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added (quad-width)
Returns:
PrimExpr: Handle to the quad-width atomic addition operation
"""
return T.call_extern("handle", "AtomicAddx4", T.address_of(dst), T.address_of(value))
def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr: def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
"""Perform a 4-element dot product with accumulation (DP4A). """Perform a 4-element dot product with accumulation (DP4A).
...@@ -294,35 +151,10 @@ def view(src: Buffer, ...@@ -294,35 +151,10 @@ def view(src: Buffer,
return T.Tensor(shape, dtype, src.data) return T.Tensor(shape, dtype, src.data)
def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr: def loop_break():
""" """Break out of the current loop.
Load a value from the given buffer using the specified atomic memory ordering.
Performs an atomic load from `src` and returns a PrimExpr representing the loaded value.
memory_order selects the ordering and must be one of: "relaxed", "consume", "acquire",
"release", "acq_rel", or "seq_cst" (default).
Raises KeyError if an unknown memory_order is provided.
"""
return T.call_extern(src.dtype, "AtomicLoad", T.address_of(src),
_MEMORY_ORDER_ID_MAP[memory_order])
def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
"""
Perform an atomic store of `src` into `dst` with the given memory ordering.
Parameters:
dst (Buffer): Destination buffer to store into.
src (PrimExpr): Value to store.
memory_order (str, optional): Memory ordering name; one of "relaxed", "consume",
"acquire", "release", "acq_rel", or "seq_cst". Defaults to "seq_cst".
The name is mapped to an internal numeric ID used by the underlying runtime.
Returns: Returns:
PrimExpr: A handle representing the issued atomic store operation. tir.Call: A call to the `tl.loop_break` intrinsic.
Raises:
KeyError: If `memory_order` is not one of the supported names.
""" """
return T.call_extern("handle", "AtomicStore", T.address_of(dst), src, return T.call_intrin("handle", op.Op.get("tl.loop_break"))
_MEMORY_ORDER_ID_MAP[memory_order])
...@@ -54,6 +54,13 @@ class PassConfigKey(str, Enum): ...@@ -54,6 +54,13 @@ class PassConfigKey(str, Enum):
TL_DISABLE_SHUFFLE_ELECT = "tl.disable_shuffle_elect" TL_DISABLE_SHUFFLE_ELECT = "tl.disable_shuffle_elect"
"""Disable shuffle election optimization. Default: False""" """Disable shuffle election optimization. Default: False"""
TL_DISABLE_THREAD_STORAGE_SYNC = "tl.disable_thread_storage_sync"
"""Disable thread storage synchronization pass. When enabled, disables the
automatic insertion of thread synchronization barriers (e.g., __syncthreads())
for shared memory access coordination. This can be useful for performance
optimization in cases where manual synchronization is preferred or when
synchronization is not needed. Default: False"""
# TIR related configs # TIR related configs
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir" TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True""" """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment