Unverified commit 557589ff, authored by Lei Wang and committed by GitHub


[Example] Introduce split+sum template, and optimize `atomic_add` performance for bwd examples (#940)

* example fix

* lint fix

* bug fix

* reduce test size.
parent 95170ab7
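
The two backward templates added below differ only in how the per-group dK/dV partial gradients are combined: the `atomic_add` template accumulates every partial directly into a single fp32 buffer with `T.atomic_add`, while the split template writes each group's partial into its own slice of a `[groups, ...]` buffer and sums afterwards. A minimal PyTorch sketch of that equivalence (shapes and names here are illustrative only, not taken from the kernels below):

import torch

groups, batch, seq_len, head_kv, dim = 4, 2, 128, 2, 64
partials = [torch.randn(batch, seq_len, head_kv, dim) for _ in range(groups)]

# "atomic_add" style: all groups accumulate into one fp32 gradient buffer.
dk_atomic = torch.zeros(batch, seq_len, head_kv, dim)
for p in partials:
    dk_atomic += p  # in-kernel this is T.atomic_add(dK[...], dk_shared)

# "split+sum" style: one slice per group, reduced after the kernel finishes.
dk_split = torch.stack(partials, dim=0).sum(0)  # matches dk.sum(0) in backward()

assert torch.allclose(dk_atomic, dk_split, atol=1e-5)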
@@ -147,7 +147,118 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
def flashattn_bwd_atomic_add(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
return flash_bwd
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
@@ -171,7 +282,7 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
@@ -202,10 +313,13 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=1):
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
@@ -213,9 +327,6 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
@@ -244,7 +355,7 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1):
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
@@ -253,6 +364,7 @@ class _attention(torch.autograd.Function):
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
@@ -268,13 +380,48 @@ class _attention(torch.autograd.Function):
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 64
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
delta = mod_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
groups)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
else:
kernel = flashattn_bwd_split(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
@@ -284,7 +431,8 @@ class _attention(torch.autograd.Function):
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None
return dq, dk, dv, None, None, None
attention = _attention.apply
@@ -321,7 +469,8 @@ def main(BATCH: int = 1,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False):
causal: bool = False,
use_atomic: bool = True):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
@@ -341,7 +490,7 @@ def main(BATCH: int = 1,
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups)
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
@@ -382,7 +531,22 @@ if __name__ == "__main__":
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument('--groups', type=int, default=16, help='groups')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
use_atomic)
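
For reference, a hypothetical direct call to this example's updated `main` entry point (argument values chosen arbitrarily for illustration; both paths verify against the PyTorch reference):

# Assumed example values; any keyword argument may be changed.
main(BATCH=1, H=32, N_CTX=1024, D_HEAD_QK=192, D_HEAD_V=128, groups=16,
     causal=True, use_atomic=False)  # exercises the split+sum backward path
main(BATCH=1, H=32, N_CTX=1024, D_HEAD_QK=192, D_HEAD_V=128, groups=16,
     causal=True, use_atomic=True)   # exercises the atomic_add backward path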
@@ -147,7 +147,129 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
def flashattn_bwd_atomic_add(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(
K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
wg_wait=-1)
T.wait_wgmma(1)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.wait_wgmma(0)
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0)
for i, j in T.Parallel(block_N, dim_qk):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared)
for i, j in T.Parallel(block_M, dim_qk):
T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
return flash_bwd
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
@@ -171,7 +293,7 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
@@ -202,7 +324,7 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(
@@ -255,7 +377,7 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1):
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
@@ -264,6 +386,7 @@ class _attention(torch.autograd.Function):
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
@@ -284,8 +407,43 @@ class _attention(torch.autograd.Function):
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
delta = mod_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
groups)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
else:
kernel = flashattn_bwd_split(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
@@ -295,7 +453,8 @@ class _attention(torch.autograd.Function):
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None
return dq, dk, dv, None, None, None
attention = _attention.apply
@@ -332,7 +491,8 @@ def main(BATCH: int = 1,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False):
causal: bool = False,
use_atomic: bool = True):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
@@ -352,7 +512,7 @@ def main(BATCH: int = 1,
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups)
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
@@ -393,7 +553,22 @@ if __name__ == "__main__":
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument('--groups', type=int, default=16, help='groups')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
use_atomic)
@@ -149,7 +149,110 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
def flashattn_bwd_atomic_add(batch,
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=128,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, accum_dtype), # type: ignore
dV: T.Tensor(shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)
return flash_bwd
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=128,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
@@ -168,13 +271,9 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -202,7 +301,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -242,13 +341,14 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
def forward(ctx, q, k, v, causal, use_atomic=True):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
@@ -267,14 +367,29 @@ class _attention(torch.autograd.Function):
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
else:
kernel = flashattn_bwd_split(
BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
return dq, dk, dv, None
return dq, dk, dv, None, None
attention = _attention.apply
@@ -300,7 +415,9 @@ def main(
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
use_atomic: bool = True,
):
print(f"Test with use_atomic: {use_atomic}")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
@@ -311,7 +428,7 @@ def main(
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O = attention(Q, K, V, causal, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
@@ -327,6 +444,7 @@ def main(
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
@@ -350,6 +468,20 @@ if __name__ == "__main__":
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic)
@@ -146,7 +146,121 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
def flashattn_bwd_atomic_add(batch,
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, accum_dtype), # type: ignore
dV: T.Tensor(shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(
K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
wg_wait=-1)
T.wait_wgmma(1)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.wait_wgmma(0)
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0)
for i, j in T.Parallel(block_N, dim):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)
return flash_bwd
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
@@ -165,13 +279,9 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -200,7 +310,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(
@@ -251,7 +361,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
def forward(ctx, q, k, v, causal, use_atomic=True):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
@@ -259,6 +369,7 @@ class _attention(torch.autograd.Function):
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
@@ -277,14 +388,29 @@ class _attention(torch.autograd.Function):
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = mod_prep(o, do)
mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
if ctx.use_atomic:
mod = flashattn_bwd_atomic_add(
BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=256, num_stages=2)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape, dtype=torch.float32, device=q.device)
mod(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
else:
mod = flashattn_bwd_split(
BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=256, num_stages=2)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
mod(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
return dq, dk, dv, None
return dq, dk, dv, None, None
attention = _attention.apply
@@ -310,7 +436,9 @@ def main(
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
use_atomic: bool = True,
):
print(f"Test with use_atomic: {use_atomic}")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
@@ -321,7 +449,7 @@ def main(
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O = attention(Q, K, V, causal, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
@@ -337,6 +465,7 @@ def main(
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
@@ -360,6 +489,20 @@ if __name__ == "__main__":
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic)
@@ -382,7 +382,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
return out
def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
......
@@ -286,7 +286,8 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
}
ICHECK(m_warp * n_warp == num_warps)
<< "m_warp * n_warp must equal num_warps";
<< "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
<< ", n_warp: " << n_warp << ", num_warps: " << num_warps;
// Store the computed values in the object's member variables
this->m_warp = m_warp;
@@ -370,6 +371,10 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
ICHECK(m_warp * n_warp == num_warps)
<< "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
<< ", n_warp: " << n_warp << ", num_warps: " << num_warps;
// Store the computed values in the object's member variables
this->m_warp = m_warp;
this->n_warp = n_warp;
......
@@ -3,9 +3,11 @@
"""Atomic operations for tilelang."""
import tilelang.language as T
from tvm import ir
from tvm import ir, tir
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
from typing import Optional
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
from tilelang.utils.language import get_buffer_region_from_load
_MEMORY_ORDER_ID_MAP = {
"relaxed": 0,
@@ -200,14 +202,17 @@ def atomic_add(dst: Buffer,
extent = max(src_extent, dst_extent)
def _to_region(data, access_type):
from .customize import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
if isinstance(data, tir.Var) and T.has_let_value(data):
if isinstance(data, Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, Buffer):
if isinstance(data, tir.Buffer):
return buffer_to_tile_region(data, access_type)
elif isinstance(data, BufferRegion):
elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
elif isinstance(data, tir.BufferLoad):
region = get_buffer_region_from_load(data)
if region is None:
return buffer_load_to_tile_region(data, access_type, extent)
return buffer_region_to_tile_region(region, access_type, extent)
else:
return buffer_load_to_tile_region(data, access_type, extent)
......
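
The `_to_region` change above matters because `atomic_add` is now called with several destination forms in the examples: a full `Buffer`, a sliced `BufferRegion` (the dK/dV tiles), and a plain element `BufferLoad` (the per-element dQ update). A minimal sketch of a kernel exercising the sliced-region form, assuming a CUDA device and using only the API calls already shown in the examples above (the kernel name and shapes are illustrative):

import tilelang
import tilelang.language as T


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def tile_accumulate(M=256, N=64, block_M=64, dtype="float"):

    @T.prim_func
    def accumulate(Src: T.Tensor([M, N], dtype), Dst: T.Tensor([M, N], dtype)):  # type: ignore
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
            tile = T.alloc_shared([block_M, N], dtype)
            # Stage one row-block of Src into shared memory.
            T.copy(Src[bx * block_M:(bx + 1) * block_M, :], tile)
            # Sliced BufferRegion destination, same call shape as the dK/dV updates above.
            T.atomic_add(Dst[bx * block_M:(bx + 1) * block_M, :], tile)

    return accumulate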
"""The language interface for tl programs.""" """The language interface for tl programs."""
from typing import Union, List, Optional, Literal from typing import Union, Optional, Literal
from tilelang import language as T from tilelang import language as T
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import get_buffer_region_from_load
from tvm import ir, tir from tvm import ir, tir
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
"""Create a memory region descriptor for tile operations.
Args:
buffer (tir.BufferLoad): The buffer to create a region for
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
*args (tir.PrimExpr): Extent expressions defining the region size
Returns:
tir.Call: A region descriptor for tile operations
"""
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return tir.call_intrin("handle", tir.op.Op.get("tl.region"), buffer, access_type, *args)
def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
"""Convert a TVM buffer to a tile region descriptor.
Args:
buffer (tir.Buffer): The buffer to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor covering the entire buffer
"""
mins = [0 for _ in buffer.shape]
extents = [x for x in buffer.shape]
return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
load (tir.BufferLoad): The buffer load operation
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
extents (List[tir.PrimExpr]): List of expressions defining the region size
Returns:
tir.Call: A region descriptor for the loaded area
"""
indices = load.indices
if len(indices) > len(extents):
# (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
# f"region will be expanded in the last 2 dimensions")
new_extents = []
for _ in range(len(indices) - len(extents)):
new_extents.append(1)
for extent in extents:
new_extents.append(extent)
extents = new_extents
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
extents: List[tir.PrimExpr]):
"""Convert a buffer region to a tile region descriptor.
Args:
buffer_region (tir.BufferRegion): The buffer region to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor for the specified buffer region
"""
mins = [x.min for x in buffer_region.region]
region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len(
extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
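
The `if len(indices) > len(extents)` branch in `buffer_load_to_tile_region` (now kept in `tilelang.language.utils`) pads missing leading extents with 1 so that a partially specified region still lines up with the buffer's full index tuple. A small pure-Python illustration of that rule (no TVM objects involved; names are illustrative):

indices = ["bz", "i", "bx", "j"]   # e.g. a 4-D access such as dQ[bz, i, bx, j]
extents = [32, 64]                 # caller supplied only the trailing two extents
extents = [1] * (len(indices) - len(extents)) + list(extents)
assert extents == [1, 1, 32, 64]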
def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion],
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
import tilelang.language as T import tilelang.language as T
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, op from tvm.tir import PrimExpr, Buffer, op
from typing import List, Union from typing import List, Union
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
"""
Create a tile memory-region descriptor for a BufferLoad.
Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic
(1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents.
Parameters:
buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices.
access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access.
*args (tir.PrimExpr): Extent expressions for each region dimension.
Returns:
tir.Call: A call to the `tl.region` intrinsic describing the memory region.
Raises:
KeyError: If access_type is not one of 'r', 'w', or 'rw'.
"""
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args)
def buffer_to_tile_region(buffer: Buffer, access_type: str):
"""Convert a TVM buffer to a tile region descriptor.
Args:
buffer (tir.Buffer): The buffer to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor covering the entire buffer
"""
mins = [0 for _ in buffer.shape]
extents = [x for x in buffer.shape]
return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
load (tir.BufferLoad): The buffer load operation
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
extents (List[tir.PrimExpr]): List of expressions defining the region size
Returns:
tir.Call: A region descriptor for the loaded area
"""
indices = load.indices
if len(indices) > len(extents):
# (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
# f"region will be expanded in the last 2 dimensions")
new_extents = []
for _ in range(len(indices) - len(extents)):
new_extents.append(1)
for extent in extents:
new_extents.append(extent)
extents = new_extents
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)
def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
extents: List[PrimExpr]):
"""
Create a tl region descriptor for the given BufferRegion.
Parameters:
buffer_region (tir.BufferRegion): Source buffer region whose `region` items provide mins and extents.
access_type (str): Access mode: "r", "w", or "rw".
extents (List[PrimExpr]): Requested extents; must have length <= the number of extents in buffer_region.region.
Returns:
tir.Call: A tile-region descriptor (tl.region) covering the buffer_region.
Raises:
AssertionError: If the number of extents in buffer_region.region is smaller than len(extents).
"""
mins = [x.min for x in buffer_region.region]
region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len(
extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
"""Perform a 4-element dot product with accumulation (DP4A).
......
from tilelang import tvm as tvm
from typing import List
from tvm import tir
from tvm.tir import PrimExpr, Buffer, BufferLoad, op
from tilelang import language as T
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
"""
Create a tile memory-region descriptor for a BufferLoad.
Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic
(1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents.
Parameters:
buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices.
access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access.
*args (tir.PrimExpr): Extent expressions for each region dimension.
Returns:
tir.Call: A call to the `tl.region` intrinsic describing the memory region.
Raises:
KeyError: If access_type is not one of 'r', 'w', or 'rw'.
"""
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args)
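As a rough illustration of what `region` produces, the sketch below builds a read region anchored at the origin of a small buffer. The shape and extents are invented for the example, and `region` is assumed to be importable from tilelang.language.utils alongside the buffer_* helpers imported earlier in this diff:
from tvm import tir
from tilelang.language.utils import region  # assumed export location

A = tir.decl_buffer((128, 64), "float16", name="A")
anchor = tir.BufferLoad(A, [0, 0])          # region starts at element (0, 0)
read_region = region(anchor, "r", 128, 64)  # 'r' maps to access code 1; covers the whole buffer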
def buffer_to_tile_region(buffer: Buffer, access_type: str):
"""Convert a TVM buffer to a tile region descriptor.
Args:
buffer (tir.Buffer): The buffer to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor covering the entire buffer
"""
mins = [0 for _ in buffer.shape]
extents = [x for x in buffer.shape]
return region(T.BufferLoad(buffer, mins), access_type, *extents)
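A small usage sketch for `buffer_to_tile_region`; the buffer here is made up, and the import path follows the tilelang.language.utils import earlier in this diff:
from tvm import tir
from tilelang.language.utils import buffer_to_tile_region

B = tir.decl_buffer((256,), "float32", name="B")
whole = buffer_to_tile_region(B, "rw")  # read-write region with min 0 and extent 256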
def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
load (tir.BufferLoad): The buffer load operation
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
extents (List[tir.PrimExpr]): List of expressions defining the region size
Returns:
tir.Call: A region descriptor for the loaded area
"""
indices = load.indices
if len(indices) > len(extents):
# (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
# f"region will be expanded in the last 2 dimensions")
new_extents = []
for _ in range(len(indices) - len(extents)):
new_extents.append(1)
for extent in extents:
new_extents.append(extent)
extents = new_extents
print("after extents", extents)
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)
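The extent-padding branch above can be exercised directly; in this sketch the load has three indices but only two extents are supplied, so a leading extent of 1 is inserted (buffer shape and values are illustrative):
from tvm import tir
from tilelang.language.utils import buffer_load_to_tile_region

C = tir.decl_buffer((4, 8, 32), "float16", name="C")
load = tir.BufferLoad(C, [2, 0, 0])                 # three indices
r = buffer_load_to_tile_region(load, "r", [8, 32])  # effective extents become [1, 8, 32]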
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
extents: List[tir.PrimExpr]):
"""Convert a buffer region to a tile region descriptor.
Args:
buffer_region (tir.BufferRegion): The buffer region to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
extents (List[tir.PrimExpr]): Requested extents; must not exceed the number of ranges in buffer_region.region
Returns:
tir.Call: A region descriptor for the specified buffer region
"""
mins = [x.min for x in buffer_region.region]
region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len(
extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
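A sketch of `buffer_region_to_tile_region` on a hand-built BufferRegion (shapes and ranges are invented for the example). Note that the returned descriptor uses the region's own mins and extents; the `extents` argument only participates in the rank check:
from tvm import ir, tir
from tilelang.language.utils import buffer_region_to_tile_region

D = tir.decl_buffer((64, 64), "float16", name="D")
rows = ir.Range.from_min_extent(16, 32)  # rows 16..47
cols = ir.Range.from_min_extent(0, 64)   # all columns
br = tir.BufferRegion(D, [rows, cols])
w = buffer_region_to_tile_region(br, "w", [32, 64])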
def index_to_coordinates(index, shape) -> List[PrimExpr]:
......
...@@ -131,8 +131,16 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.Buf
"""
buffer, indices = buffer_load.buffer, buffer_load.indices
regions = []
found_ramp: bool = False
for indice in indices:
if isinstance(indice, tir.Ramp):
regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
found_ramp = True
elif isinstance(indice, tir.PrimExpr):
regions.append(ir.Range.from_min_extent(indice, 1))
else:
raise ValueError("Unsupported type: ", type(indice))
if found_ramp:
return tir.BufferRegion(buffer, regions)
else:
return None
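To illustrate the new behavior, the sketch below converts a vectorized load (one `Ramp` index) into a BufferRegion, and shows that a purely scalar load still yields None. The buffer and indices are illustrative; the import path follows the one used earlier in this diff:
from tvm import tir
from tilelang.utils.language import get_buffer_region_from_load

E = tir.decl_buffer((16, 128), "float16", name="E")
vec_load = tir.BufferLoad(E, [3, tir.Ramp(0, 1, 8)])
region_e = get_buffer_region_from_load(vec_load)         # BufferRegion covering row 3, columns 0..7
scalar_load = tir.BufferLoad(E, [3, 5])
none_result = get_buffer_region_from_load(scalar_load)   # None: no Ramp index present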