Unverified commit 481cae42, authored by Tong WU, committed by GitHub

[Example] Revert the atomic/split&sum templates in MHA backward examples (#943)



* revert split+sum template for MHA backward

* lint

* Update example_mha_bwd.py

* Update example_mha_bwd_wgmma_pipelined.py

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 3aecab8f
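For readers skimming the diff: the reverted "atomic" template accumulated dK/dV into zero-initialized float32 buffers with T.atomic_add and cast them to float16 on the host, while the kernel kept by this revert (like the old "split" template) lets each (K, V) tile's thread block write its float16 result once with T.copy. A minimal PyTorch-only sketch of that host-side difference follows; it is illustrative only, not code from this repository.

import torch

N, D = 8, 4
parts = [torch.randn(N, D) for _ in range(3)]  # stand-ins for per-block partial gradients

# (a) Accumulate-then-cast, as in the removed atomic-add template: every contributor
# adds into the same buffer, so it must start at zero and stay in float32 until the end.
acc = torch.zeros(N, D, dtype=torch.float32)
for p in parts:
    acc += p                      # kernel analogue: T.atomic_add(dK[...], dk_shared)
out_accumulated = acc.to(torch.float16)

# (b) Write-once, as in the retained kernel: each thread block owns a disjoint slice of
# dK/dV, so an uninitialized float16 buffer and a single store per slice are enough.
out_direct = torch.empty(N, D, dtype=torch.float16)
total = sum(parts)                # what the sole owner of each slice would have computed
half = N // 2
for start in (0, half):           # two hypothetical owners with disjoint row ranges
    out_direct[start:start + half] = total[start:start + half].to(torch.float16)

assert torch.allclose(out_accumulated, out_direct, atol=1e-2, rtol=1e-2)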
example_mha_bwd.py
@@ -149,110 +149,7 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
-def flashattn_bwd_atomic_add(batch,
+def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=128,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, accum_dtype), # type: ignore
dV: T.Tensor(shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)
return flash_bwd
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=128,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
@@ -271,9 +168,13 @@ def flashattn_bwd_split(batch,
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
-with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
+with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -301,7 +202,7 @@ def flashattn_bwd_split(batch,
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
-for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
+for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -328,8 +229,7 @@ def flashattn_bwd_split(batch,
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
-if k * block_N + i < seq_len:
-    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
@@ -341,14 +241,13 @@ def flashattn_bwd_split(batch,
class _attention(torch.autograd.Function):
@staticmethod
-def forward(ctx, q, k, v, causal, use_atomic=True):
+def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
-ctx.use_atomic = use_atomic
return o
@staticmethod
@@ -367,29 +266,14 @@ class _attention(torch.autograd.Function):
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = kernel_prep(o, do)
-if ctx.use_atomic:
-    kernel = flashattn_bwd_atomic_add(
-        BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
-    shape = [BATCH, N_CTX, H, D_HEAD]
-    dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
-    dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
-    dv = torch.zeros(shape, dtype=torch.float32, device=q.device)
-    kernel(q, k, v, do, lse, delta, dq, dk, dv)
-    dq = kernel_post(dq)
-    dk = dk.to(torch.float16)
-    dv = dv.to(torch.float16)
-else:
-    kernel = flashattn_bwd_split(
-        BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
-    shape = [BATCH, N_CTX, H, D_HEAD]
-    dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
-    dk = torch.empty(shape, dtype=torch.float16, device=q.device)
-    dv = torch.empty(shape, dtype=torch.float16, device=q.device)
-    kernel(q, k, v, do, lse, delta, dq, dk, dv)
-    dq = kernel_post(dq)
-return dq, dk, dv, None, None
+kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
+shape = [BATCH, N_CTX, H, D_HEAD]
+dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
+dk = torch.empty(shape, dtype=torch.float16, device=q.device)
+dv = torch.empty(shape, dtype=torch.float16, device=q.device)
+kernel(q, k, v, do, lse, delta, dq, dk, dv)
+dq = kernel_post(dq)
+return dq, dk, dv, None
attention = _attention.apply
@@ -415,9 +299,7 @@ def main(
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
-use_atomic: bool = True,
):
-print(f"Test with use_atomic: {use_atomic}")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
@@ -428,7 +310,7 @@ def main(
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
-O = attention(Q, K, V, causal, use_atomic)
+O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
@@ -444,7 +326,6 @@ def main(
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
-print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
@@ -468,20 +349,6 @@ if __name__ == "__main__":
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
-parser.add_argument('--causal', action='store_true', help='Causal flag')
-parser.add_argument(
-    '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-parser.add_argument(
-    '--use_split', action='store_true', default=False, help='Use split for dK/dV')
+parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
args = parser.parse_args()
-# Handle backward compatibility and logic
-if args.use_split:
-    use_atomic = False
-elif args.use_atomic:
-    use_atomic = True
-else:
-    # Default: use atomic
-    use_atomic = True
-main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic)
+main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
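Between the two files, a note on the math: both backward kernels implement the same per-tile computation (standard FlashAttention-style backward; this summary is read off the code above and is not part of the commit). With i indexing rows of the K/V tile (block_M), j indexing rows of the Q/dO tile (block_N), d the head dimension, L_j the saved log-sum-exp, and Delta_j = dO_j · O_j from the preprocess kernel:

\begin{aligned}
P_{ij}  &= \exp\!\left(\frac{q_j \cdot k_i}{\sqrt{d}} - L_j\right)
          && \text{qkT, computed via exp2 with the log2(e) factor folded into scale} \\
dV_i    &\mathrel{+}= \sum_j P_{ij}\, dO_j
          && \text{gemm(qkT\_cast, do, dv)} \\
dP_{ij} &= dO_j \cdot v_i
          && \text{dsT} \\
dS_{ij} &= \frac{1}{\sqrt{d}}\, P_{ij}\,\bigl(dP_{ij} - \Delta_j\bigr)
          && \text{dsT\_cast} \\
dK_i    &\mathrel{+}= \sum_j dS_{ij}\, q_j,
\qquad
dQ_j    \mathrel{+}= \sum_i dS_{ij}\, k_i
          && \text{gemm(dsT\_cast, q, dk); dQ accumulated with atomic\_add}
\end{aligned}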
example_mha_bwd_wgmma_pipelined.py
@@ -146,121 +146,7 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
-def flashattn_bwd_atomic_add(batch,
+def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, accum_dtype), # type: ignore
dV: T.Tensor(shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(
K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
wg_wait=-1)
T.wait_wgmma(1)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.wait_wgmma(0)
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0)
for i, j in T.Parallel(block_N, dim):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)
return flash_bwd
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
@@ -279,9 +165,13 @@ def flashattn_bwd_split(batch,
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
-with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
+with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -310,7 +200,7 @@ def flashattn_bwd_split(batch,
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
-for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
+for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(
@@ -348,8 +238,7 @@ def flashattn_bwd_split(batch,
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0)
for i, j in T.Parallel(block_N, dim):
-if k * block_N + i < seq_len:
-    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
@@ -361,7 +250,7 @@ def flashattn_bwd_split(batch,
class _attention(torch.autograd.Function):
@staticmethod
-def forward(ctx, q, k, v, causal, use_atomic=True):
+def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
@@ -369,7 +258,6 @@ class _attention(torch.autograd.Function):
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
-ctx.use_atomic = use_atomic
return o
@staticmethod
@@ -388,29 +276,14 @@ class _attention(torch.autograd.Function):
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = mod_prep(o, do)
-if ctx.use_atomic:
-    mod = flashattn_bwd_atomic_add(
-        BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=256, num_stages=2)
-    shape = [BATCH, N_CTX, H, D_HEAD]
-    dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
-    dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
-    dv = torch.zeros(shape, dtype=torch.float32, device=q.device)
-    mod(q, k, v, do, lse, delta, dq, dk, dv)
-    dq = mod_post(dq)
-    dk = dk.to(torch.float16)
-    dv = dv.to(torch.float16)
-else:
-    mod = flashattn_bwd_split(
-        BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=256, num_stages=2)
-    shape = [BATCH, N_CTX, H, D_HEAD]
-    dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
-    dk = torch.empty(shape, dtype=torch.float16, device=q.device)
-    dv = torch.empty(shape, dtype=torch.float16, device=q.device)
-    mod(q, k, v, do, lse, delta, dq, dk, dv)
-    dq = mod_post(dq)
-return dq, dk, dv, None, None
+mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
+shape = [BATCH, N_CTX, H, D_HEAD]
+dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
+dk = torch.empty(shape, dtype=torch.float16, device=q.device)
+dv = torch.empty(shape, dtype=torch.float16, device=q.device)
+mod(q, k, v, do, lse, delta, dq, dk, dv)
+dq = mod_post(dq)
+return dq, dk, dv, None
attention = _attention.apply
@@ -436,9 +309,7 @@ def main(
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
-use_atomic: bool = True,
):
-print(f"Test with use_atomic: {use_atomic}")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
@@ -449,7 +320,7 @@ def main(
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
-O = attention(Q, K, V, causal, use_atomic)
+O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
@@ -465,7 +336,6 @@ def main(
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
-print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
@@ -489,20 +359,6 @@ if __name__ == "__main__":
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
-parser.add_argument('--causal', action='store_true', help='Causal flag')
-parser.add_argument(
-    '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-parser.add_argument(
-    '--use_split', action='store_true', default=False, help='Use split for dK/dV')
+parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
args = parser.parse_args()
-# Handle backward compatibility and logic
-if args.use_split:
-    use_atomic = False
-elif args.use_atomic:
-    use_atomic = True
-else:
-    # Default: use atomic
-    use_atomic = True
-main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic)
+main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
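End of the second file. Both examples validate against a PyTorch reference (O_ref / dQ_ref / dK_ref / dV_ref in main above). A self-contained sketch of such a reference check, assuming the [BATCH, N_CTX, H, D_HEAD] layout used by the examples; the helper name ref_attention and the concrete sizes here are illustrative, not the repository's code:

import torch
import torch.nn.functional as F

def ref_attention(q, k, v, causal):
    # F.scaled_dot_product_attention expects [B, H, N, D]; the examples use [B, N, H, D].
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
    o = F.scaled_dot_product_attention(qt, kt, vt, is_causal=causal)
    return o.transpose(1, 2)

B, H, N_CTX, D_HEAD = 1, 32, 1024, 64
Q = torch.randn(B, N_CTX, H, D_HEAD, device="cuda", dtype=torch.float16, requires_grad=True)
K = torch.randn_like(Q, requires_grad=True)
V = torch.randn_like(Q, requires_grad=True)
dO = torch.randn_like(Q)

O_ref = ref_attention(Q, K, V, causal=False)
O_ref.backward(dO)
dQ_ref, dK_ref, dV_ref = Q.grad, K.grad, V.grad
# Compare against dQ/dK/dV produced by attention(Q, K, V, causal) with
# torch.allclose(..., rtol=1e-2, atol=1e-2), as main() does above.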