"docs/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "d416bc40970040553d7b98127829135a1e1fc22e"
Unverified Commit 2de566e7 authored by Kevinzz, committed by GitHub

[BugFix] Remove memory_order in atomic constexpr and fix NSA bwd (#1260)



* Fix NSA bwd and the atomic ops

* [Lint]

* [BugFix]
- New implementation of atomicMax and atomicMin for fp16/bf16 using an atomicCAS loop (a standalone sketch of the idea follows after the commit message)
- PTX-based atomicAdd for a single 16-bit value under non-relaxed memory orders
- Update the test cases

* [Lint]

---------
Co-authored-by: tzj-fxz <tzjfxz@gmail.com>
parent 729e66ca
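For reference, here is a minimal standalone sketch of the atomicCAS-loop approach named in the commit message, written against the public CUDA fp16 API. The helper and kernel names (atomic_max_f16, block_max_kernel) are illustrative only and are not part of this commit; the committed implementation is the templated AtomicMax/AtomicMin shown in the diff below, which also goes through tilelang's normalize_atomic_type. atomicCAS on unsigned short requires sm_70 or newer, and the half comparison intrinsics require sm_53 or newer.

// Sketch only: mirrors the CAS-loop idea, not the committed implementation.
#include <cuda_fp16.h>

__device__ void atomic_max_f16(__half *address, __half val) {
  // View the 16-bit payload as unsigned short so it can go through atomicCAS.
  unsigned short *addr_us = reinterpret_cast<unsigned short *>(address);
  unsigned short val_us = __half_as_ushort(val);
  unsigned short old_us = *addr_us;
  // Retry while our value is still greater than the currently stored one.
  while (__hgt(val, __ushort_as_half(old_us))) {
    unsigned short assumed = old_us;
    old_us = atomicCAS(addr_us, assumed, val_us);
    if (assumed == old_us) {
      break; // CAS succeeded; val is now stored.
    }
  }
}

__global__ void block_max_kernel(__half *result, const __half *data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomic_max_f16(result, data[i]);
  }
}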
@@ -106,8 +106,8 @@ def tilelang_kernel_fwd(
             T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)
             if is_causal:
-                for i, j in T.Parallel(G, BS):
-                    acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0,
+                for k, j in T.Parallel(G, BS):
+                    acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0,
                                                  -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
@@ -124,18 +124,18 @@ def tilelang_kernel_fwd(
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
             T.reduce_max(acc_s, scores_max, dim=1, clear=True)
-            for i in T.Parallel(G):
-                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
-            for i, j in T.Parallel(G, BS):
-                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
+            for k in T.Parallel(G):
+                scores_scale[k] = T.exp2(scores_max_prev[k] * scale - scores_max[k] * scale)
+            for k, j in T.Parallel(G, BS):
+                acc_s[k, j] = T.exp2(acc_s[k, j] * scale - scores_max[k] * scale)
             T.reduce_sum(acc_s, scores_sum, dim=1)
-            for i in T.Parallel(G):
-                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+            for k in T.Parallel(G):
+                logsum[k] = logsum[k] * scores_scale[k] + scores_sum[k]
             T.copy(acc_s, acc_s_cast)
             # Rescale
-            for i, j in T.Parallel(G, BV):
-                acc_o[i, j] *= scores_scale[i]
+            for k, j in T.Parallel(G, BV):
+                acc_o[k, j] *= scores_scale[k]
             # V * softmax(Q * K)
             T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
@@ -465,8 +465,8 @@ def tilelang_kernel_bwd_dqkv(
                 T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                 # [G]
                 T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta)
-                for i, j in T.Parallel(BS, G):
-                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
+                for _i, _j in T.Parallel(BS, G):
+                    dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale
                 # [BS, G] @ [G, BK] -> [BS, BK]
                 T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
@@ -46,10 +46,22 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr ((std::is_same_v<NT1, half> ||
-                 std::is_same_v<NT1, __nv_bfloat16>) &&
-                memory_order == int(cuda::memory_order_relaxed)) {
-    atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    // There is no implementation of atomicMax for half and bf16 in cuda.
+    // We simulate this process by atomicCAS loop.
+    unsigned short *address_as_ushort =
+        reinterpret_cast<unsigned short *>(address);
+    unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
+    unsigned short old_val_ushort = *address_as_ushort;
+    while (val > *reinterpret_cast<T1 *>(&old_val_ushort)) {
+      unsigned short assumed_val_ushort = old_val_ushort;
+      old_val_ushort =
+          atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
+      if (assumed_val_ushort == old_val_ushort) {
+        break;
+      }
+    }
   } else {
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
@@ -61,11 +73,21 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr ((std::is_same_v<NT1, half> ||
-                 std::is_same_v<NT1, __nv_bfloat16>) &&
-                memory_order == int(cuda::memory_order_relaxed)) {
-    return static_cast<T1>(
-        atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    unsigned short *address_as_ushort =
+        reinterpret_cast<unsigned short *>(address);
+    unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
+    unsigned short old_val_ushort = *address_as_ushort;
+    while (val > *reinterpret_cast<T1 *>(&old_val_ushort)) {
+      unsigned short assumed_val_ushort = old_val_ushort;
+      old_val_ushort =
+          atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
+      if (assumed_val_ushort == old_val_ushort) {
+        break;
+      }
+    }
+    return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
   } else {
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
@@ -78,10 +100,22 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr ((std::is_same_v<NT1, half> ||
-                 std::is_same_v<NT1, __nv_bfloat16>) &&
-                memory_order == int(cuda::memory_order_relaxed)) {
-    atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    // There is no implementation of atomicMin for half and bf16 in cuda.
+    // We simulate this process by atomicCAS loop.
+    unsigned short *address_as_ushort =
+        reinterpret_cast<unsigned short *>(address);
+    unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
+    unsigned short old_val_ushort = *address_as_ushort;
+    while (val < *reinterpret_cast<T1 *>(&old_val_ushort)) {
+      unsigned short assumed_val_ushort = old_val_ushort;
+      old_val_ushort =
+          atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
+      if (assumed_val_ushort == old_val_ushort) {
+        break;
+      }
+    }
   } else {
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
@@ -93,11 +127,21 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr ((std::is_same_v<NT1, half> ||
-                 std::is_same_v<NT1, __nv_bfloat16>) &&
-                memory_order == int(cuda::memory_order_relaxed)) {
-    return static_cast<T1>(
-        atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    unsigned short *address_as_ushort =
+        reinterpret_cast<unsigned short *>(address);
+    unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
+    unsigned short old_val_ushort = *address_as_ushort;
+    while (val < *reinterpret_cast<T1 *>(&old_val_ushort)) {
+      unsigned short assumed_val_ushort = old_val_ushort;
+      old_val_ushort =
+          atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
+      if (assumed_val_ushort == old_val_ushort) {
+        break;
+      }
+    }
+    return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
   } else {
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
@@ -110,10 +154,67 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr ((std::is_same_v<NT1, half> ||
-                 std::is_same_v<NT1, __nv_bfloat16>) &&
-                memory_order == int(cuda::memory_order_relaxed)) {
-    atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    if (memory_order == int(cuda::memory_order_relaxed)) {
+      atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
+    } else {
+      // Since atomic ref do not support memory order, we need to inline ptx
+      // code here for each situation
+      if constexpr (std::is_same_v<NT1, half>) {
+        // fp16
+        __half ret_val;
+        unsigned short ret_val_cast =
+            *reinterpret_cast<unsigned short *>(&ret_val);
+        unsigned long long ref_address =
+            reinterpret_cast<unsigned long long>(address);
+        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        if (memory_order == int(cuda::memory_order_release) ||
+            memory_order == int(cuda::memory_order_consume)) {
+          asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acquire)) {
+          asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acq_rel) ||
+                   memory_order == int(cuda::memory_order_seq_cst)) {
+          asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        }
+      } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
+        // bf16
+        __nv_bfloat16 ret_val;
+        unsigned short ret_val_cast =
+            *reinterpret_cast<unsigned short *>(&ret_val);
+        unsigned long long ref_address =
+            reinterpret_cast<unsigned long long>(address);
+        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        if (memory_order == int(cuda::memory_order_release) ||
+            memory_order == int(cuda::memory_order_consume)) {
+          asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acquire)) {
+          asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acq_rel) ||
+                   memory_order == int(cuda::memory_order_seq_cst)) {
+          asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        }
+      }
+    }
   } else {
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
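As a usage illustration (not part of the diff), the PTX branch above is reached whenever a caller requests a non-relaxed ordering on an fp16/bf16 element; relaxed ordering still takes the native atomicAdd fast path. A hedged sketch, assuming the modified atomic header is on the include path, that TL_DEVICE functions are callable from kernel code, and that cuda::memory_order is visible through that header:

// Sketch only; AtomicAdd comes from the tilelang CUDA atomic header above.
#include <cuda_fp16.h>

__global__ void accumulate_release(half *sum, const half *vals, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // relaxed -> native atomicAdd fast path;
    // release/acquire/acq_rel/seq_cst -> the inline-PTX atom.* path added here.
    AtomicAdd(sum[0], vals[i], int(cuda::memory_order_release));
  }
}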
@@ -125,11 +226,69 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr ((std::is_same_v<NT1, half> ||
-                 std::is_same_v<NT1, __nv_bfloat16>) &&
-                memory_order == int(cuda::memory_order_relaxed)) {
-    return static_cast<T1>(
-        atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    if (memory_order == int(cuda::memory_order_relaxed)) {
+      return static_cast<T1>(
+          atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
+    } else {
+      if constexpr (std::is_same_v<NT1, half>) {
+        // fp16
+        __half ret_val;
+        unsigned short ret_val_cast =
+            *reinterpret_cast<unsigned short *>(&ret_val);
+        unsigned long long ref_address =
+            reinterpret_cast<unsigned long long>(address);
+        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        if (memory_order == int(cuda::memory_order_release) ||
+            memory_order == int(cuda::memory_order_consume)) {
+          asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acquire)) {
+          asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acq_rel) ||
+                   memory_order == int(cuda::memory_order_seq_cst)) {
+          asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        }
+        return static_cast<T1>(*reinterpret_cast<__half *>(&ret_val_cast));
+      } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
+        // bf16
+        __nv_bfloat16 ret_val;
+        unsigned short ret_val_cast =
+            *reinterpret_cast<unsigned short *>(&ret_val);
+        unsigned long long ref_address =
+            reinterpret_cast<unsigned long long>(address);
+        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        if (memory_order == int(cuda::memory_order_release) ||
+            memory_order == int(cuda::memory_order_consume)) {
+          asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acquire)) {
+          asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acq_rel) ||
+                   memory_order == int(cuda::memory_order_seq_cst)) {
+          asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        }
+        return static_cast<T1>(
+            *reinterpret_cast<__nv_bfloat16 *>(&ret_val_cast));
+      }
+    }
   } else {
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
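The *Ret variants return the previously observed value at the destination (the PTX atom instruction's destination operand, or the return value of atomicAdd on the relaxed path). A hedged usage sketch under the same assumptions as the previous sketch:

// Sketch only; AtomicAddRet comes from the modified header above.
__global__ void add_and_report(half *sum, half *prev_out, const half *vals,
                               int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // prev holds sum[0] as observed just before this thread's addition.
    half prev = AtomicAddRet(sum[0], vals[i], int(cuda::memory_order_acq_rel));
    prev_out[i] = prev;
  }
}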
@@ -236,7 +236,31 @@ def run_atomic_addx2(M, N, block_M, block_N):
     torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)


-@tilelang.jit
+def test_atomic_add():
+    run_atomic_add(8, 128, 128, 32, 32)
+
+
+def test_atomic_max():
+    run_atomic_max(4, 64, 64, 16, 16)
+
+
+def test_atomic_min():
+    run_atomic_min(4, 64, 64, 16, 16)
+
+
+def test_atomic_load_store():
+    run_atomic_load_store(64, 64, 16, 16)
+
+
+def test_atomic_memory_order():
+    run_atomic_memory_order(4, 64, 64, 16, 16)
+
+
+def test_atomic_addx2():
+    run_atomic_addx2(32, 64, 8, 16)
+
+
+@tilelang.jit(debug_root_path="./testing/python/language")
 def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):

     @T.prim_func
@@ -248,9 +272,9 @@ def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"
                     idx_j = by * block_N + j
                     if idx_i < M and idx_j < N:
                         val = A[idx_i, idx_j]
-                        T.atomic_add(B[idx_i, idx_j], val, memory_order="relaxed")
-                        T.atomic_max(C[idx_i, idx_j], val, memory_order="acquire")
-                        T.atomic_min(D[idx_i, idx_j], val, memory_order="release")
+                        T.atomic_add(B[idx_i, idx_j], val, memory_order="release")
+                        T.atomic_max(C[idx_i, idx_j], val, memory_order="relaxed")
+                        T.atomic_min(D[idx_i, idx_j], val, memory_order="relaxed")

     return atomic_different_orders
@@ -271,30 +295,6 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
     torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float('inf')), A))


-def test_atomic_add():
-    run_atomic_add(8, 128, 128, 32, 32)
-
-
-def test_atomic_max():
-    run_atomic_max(4, 64, 64, 16, 16)
-
-
-def test_atomic_min():
-    run_atomic_min(4, 64, 64, 16, 16)
-
-
-def test_atomic_load_store():
-    run_atomic_load_store(64, 64, 16, 16)
-
-
-def test_atomic_memory_order():
-    run_atomic_memory_order(4, 64, 64, 16, 16)
-
-
-def test_atomic_addx2():
-    run_atomic_addx2(32, 64, 8, 16)
-
-
 @tilelang.jit
 def atomic_addx4_program(M, N, block_M, block_N):
@@ -361,7 +361,9 @@ def run_atomic_return_prev(M, N, block_M, block_N, dtype="float32"):

 def test_atomic_different_memory_orders():
-    run_atomic_different_memory_orders(32, 32, 8, 8)
+    run_atomic_different_memory_orders(32, 32, 8, 8, dtype="float")
+    run_atomic_different_memory_orders(32, 32, 8, 8, dtype="float16")
+    run_atomic_different_memory_orders(32, 32, 8, 8, dtype="bfloat16")


 def test_atomic_addx4():