Unverified Commit f003f371 authored by Zhengju Tang, committed by GitHub

[GQA] Add regional atomic add to slightly boost performance (#1093)

* [Lint]

* [BugFix] Freeze the memory order of all atomic_add operations

* [Lint]

* [Atomic] Move on to regional atomic add

* [Lint]
parent 5cb5c068
@@ -366,21 +366,21 @@ def flashattn_bwd_atomic_add(batch,
                 T.copy(dsT_cast, dsT_shared)
                 T.clear(dq)
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
-                for i, d in T.Parallel(block_N, dim_qk):
-                    T.atomic_add(
-                        dQ[q_start_idx + k_base * block_N + i, bx, d],
-                        dq[i, d],
-                        memory_order="release")
+                T.atomic_add(
+                    dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N,
+                       bx, :],
+                    dq,
+                    memory_order="release")
-                for i, d in T.Parallel(block_M, dim_v):
-                    T.atomic_add(
-                        dV[k_start_idx + by * block_M + i, bx // groups, d],
-                        dv[i, d],
-                        memory_order="release")
+                T.atomic_add(
+                    dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
+                       bx // groups, :],
+                    dv,
+                    memory_order="release")
-                for i, d in T.Parallel(block_M, dim_qk):
-                    T.atomic_add(
-                        dK[k_start_idx + by * block_M + i, bx // groups, d],
-                        dk[i, d],
-                        memory_order="release")
+                T.atomic_add(
+                    dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
+                       bx // groups, :],
+                    dk,
+                    memory_order="release")
     return flash_bwd
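
The shape of the change: instead of issuing one scalar T.atomic_add per element from inside a T.Parallel loop, the kernel now passes T.atomic_add a whole tile-sized region of the destination tensor together with the accumulator fragment, which presumably lets the lowering emit fewer, wider atomic operations (hence the "slightly boost performance" in the title). Below is a minimal, self-contained sketch of the same pattern outside the flash-attention kernel. It is illustrative only: the function name tile_accumulate, the shapes, and the @tilelang.jit wrapper are assumptions, not code from this PR; only the two T.atomic_add forms mirror the diff above.

# Minimal sketch of the per-element vs. regional T.atomic_add pattern.
# Illustrative only; not code from this PR.
import tilelang
import tilelang.language as T


@tilelang.jit
def tile_accumulate(M, N, block_M, dtype="float32"):

    @T.prim_func
    def main(
            Src: T.Tensor((M, N), dtype),
            Dst: T.Tensor((M, N), dtype),
    ):
        # One thread block per row tile; each block atomically accumulates its
        # tile of Src into Dst (keep N small, since the fragment lives in registers).
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as bm:
            acc = T.alloc_fragment((block_M, N), dtype)
            T.copy(Src[bm * block_M, 0], acc)

            # Old form (what the diff removes): one scalar atomic per element.
            # for i, j in T.Parallel(block_M, N):
            #     T.atomic_add(Dst[bm * block_M + i, j], acc[i, j],
            #                  memory_order="release")

            # New form (what the diff adds): a single regional atomic add over
            # the whole destination tile.
            T.atomic_add(
                Dst[bm * block_M:bm * block_M + block_M, :],
                acc,
                memory_order="release")

    return main

As in the diff, memory_order="release" is passed explicitly, keeping the ordering that the earlier "[BugFix] Freeze the memory order of all atomic_add operations" commit in this PR pinned down.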