Unverified Commit f003f371 authored by Zhengju Tang, committed by GitHub

[GQA] Add regional atomic add to slightly boost performance (#1093)

* [Lint]

* [BugFix] Freeze the memory order of all atomic_add operations

* [Lint]

* [Atomic] Move on to regional atomic add

* [Lint]
parent 5cb5c068
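For context on what "regional atomic add" means here: the kernel previously issued one T.atomic_add per scalar element inside a T.Parallel loop, and this commit replaces each of those loops with a single T.atomic_add over the whole accumulator tile (a slice of the destination buffer). Below is a minimal sketch of the two patterns using the dK accumulator as the example; the buffer names, tile sizes, and index bases are taken from the diff that follows rather than from a standalone program, and the stated performance gain presumably comes from issuing one tile-wide read-modify-write instead of one per scalar.

    # Before: element-wise atomic adds, one call per (i, d) entry of the dk tile.
    for i, d in T.Parallel(block_M, dim_qk):
        T.atomic_add(
            dK[k_start_idx + by * block_M + i, bx // groups, d],
            dk[i, d],
            memory_order="release")

    # After: a single regional atomic add over the block_M x dim_qk slice of dK,
    # accumulating the whole dk tile in one call.
    T.atomic_add(
        dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
           bx // groups, :],
        dk,
        memory_order="release")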
...
@@ -366,23 +366,23 @@ def flashattn_bwd_atomic_add(batch,
                 T.copy(dsT_cast, dsT_shared)
                 T.clear(dq)
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
-                for i, d in T.Parallel(block_N, dim_qk):
-                    T.atomic_add(
-                        dQ[q_start_idx + k_base * block_N + i, bx, d],
-                        dq[i, d],
-                        memory_order="release")
-                for i, d in T.Parallel(block_M, dim_v):
-                    T.atomic_add(
-                        dV[k_start_idx + by * block_M + i, bx // groups, d],
-                        dv[i, d],
-                        memory_order="release")
-                for i, d in T.Parallel(block_M, dim_qk):
-                    T.atomic_add(
-                        dK[k_start_idx + by * block_M + i, bx // groups, d],
-                        dk[i, d],
-                        memory_order="release")
+                T.atomic_add(
+                    dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N,
+                       bx, :],
+                    dq,
+                    memory_order="release")
+                T.atomic_add(
+                    dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
+                       bx // groups, :],
+                    dv,
+                    memory_order="release")
+                T.atomic_add(
+                    dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
+                       bx // groups, :],
+                    dk,
+                    memory_order="release")
     return flash_bwd
...