Unverified Commit 05507037 authored by Yu Cheng, committed by GitHub

[Feature][Example] Support TMA reduce operation and update GQA bwd example (#969)



* [Feature][Example] Support TMA reduce operation and update GQA bwd example

* move GQA bwd with TMA reduce to new example

* [Lint]: [pre-commit.ci] auto fixes [...]

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 77b9d08e
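The core API change is a new `use_tma` flag on `T.atomic_add` for tile-region accumulation. A minimal sketch of the call pattern, lifted from the example kernel below (buffer names, shapes, and dtypes are the example's own; this fragment is not a standalone program):

dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.copy(dq, dq_shared)
# use_tma=True lowers the accumulation to a single TMA bulk reduce-add
# (cp.reduce.async.bulk ... .add.f32) instead of per-element atomicAdd calls.
T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True)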
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
import argparse
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
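# Online softmax: track the running row max and rescale the existing output accumulator
# and log-sum by exp2((old_max - new_max) * scale) before folding in this K/V tile.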
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
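# Store the per-row log-sum-exp in base-2 form: lse = log2(sum_j exp2(s_ij * scale));
# the backward kernels recompute the softmax as exp2(s * scale - lse) without renormalizing.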
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
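# Delta[b, h, i] = sum_d O[b, i, h, d] * dO[b, i, h, d], accumulated over dim_v in blk-wide
# chunks; it is the rowwise dot product needed for dS = P * (dP - Delta).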
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# bshd -> bhld so each (batch, head) slice is contiguous over (seq, dim), letting the TMA
# bulk reduce-add treat one dQ/dK/dV tile as a single contiguous region
return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
@tilelang.jit(
out_idx=[3, 4, 5], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
dtype = "float16"
accum_dtype = "float"
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(q_shape, dtype), # type: ignore
dK_out: T.Tensor(k_shape, dtype), # type: ignore
dV_out: T.Tensor(v_shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
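# The fp32 accumulators were written through the bhld layout (make_dq_layout) for TMA
# reduction; re-annotate it here so the copy reads them correctly while casting to fp16 bshd.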
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk,
by, :])
with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz):
T.annotate_layout({
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
})
T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :], dK_out[bz, bx * blk:(bx + 1) * blk,
by, :])
T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :], dV_out[bz, bx * blk:(bx + 1) * blk,
by, :])
return flash_bwd_post
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
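# Causal masking: only query positions at or after this K/V block's start contribute,
# so the Q/dO loop can start at floor(by * block_M / block_N).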
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
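# Stage the dQ tile in shared memory, then accumulate it into global memory with one
# TMA bulk reduce-add per block (use_tma=True) instead of per-element atomics.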
T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True)
T.copy(dv, dv_shared)
T.atomic_add(
dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
T.copy(dk, dk_shared)
T.atomic_add(
dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)
return flash_bwd
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
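# Split path: dQ is still accumulated with per-element atomicAdd (no TMA reduce here);
# dK/dV are instead written below to per-group slices and summed on the host.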
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared)
T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
delta = mod_prep(o, do)
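# Two dK/dV strategies: accumulate into fp32 buffers with TMA-reduce atomic adds
# (use_atomic), or write per-group fp16 slices and sum them afterwards (split).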
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv)
else:
kernel = flashattn_bwd_split(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups
K = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
arch = nvcc.get_target_compute_version()
print(f"Detected GPU compute capability: {arch}")
assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument('--groups', type=int, default=16, help='groups')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
# Resolve the dK/dV accumulation mode: --use_split takes precedence, otherwise default to atomic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
use_atomic)
@@ -80,7 +80,10 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
node->use_tma = Downcast<IntImm>(args[2]);
}
if (args.size() >= 4) {
node->coalesced_width = Downcast<IntImm>(args[3]);
}
data_ = std::move(node);
}
@@ -169,6 +172,18 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
return indices;
}
std::pair<Array<PrimExpr>, PrimExpr>
AtomicAddNode::ReturnIndicesAndSize(int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
PrimExpr size = 1;
for (size_t i = 0; i < ranges.size(); i++) {
indices.push_back(ranges[i]->min);
size *= ranges[i]->extent;
}
return {indices, size};
}
/**
* @brief Build a combined bound-check predicate for indexed access.
*
@@ -350,6 +365,28 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
*/
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
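// TMA path: lower the whole tile update to one bulk reduce-add (tma_store with
// need_reduce = 1) from the shared-memory source into the global destination,
// issued by a single thread per block.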
if (use_tma->value != 0) {
Array<PrimExpr> src_indices, dst_indices;
PrimExpr src_size, dst_size;
std::tie(src_indices, src_size) = ReturnIndicesAndSize(0);
std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1);
ICHECK(analyzer->CanProveEqual(src_size, dst_size))
<< "src_size = " << src_size << ", dst_size = " << dst_size;
BufferLoad src_node = BufferLoad(src, src_indices);
BufferLoad dst_node = BufferLoad(dst, dst_indices);
Call address_of_src =
Call(DataType::Handle(), builtin::address_of(), {src_node});
Call address_of_dst =
Call(DataType::Handle(), builtin::address_of(), {dst_node});
int need_reduce = 1;
int eviction_policy = 0;
auto body = Evaluate(Call(DataType::Handle(), tma_store(),
{address_of_src, address_of_dst,
ceildiv(src_size * src->dtype.bits(), 8),
need_reduce, eviction_policy}));
return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
}
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = ParallelOp(fused_loop);
...
@@ -20,6 +20,7 @@ public:
Buffer src, dst; ///< Source and destination buffers
Array<Range> src_range,
dst_range; ///< Access ranges for source and destination
IntImm use_tma; ///< Whether to use TMA for memory operations
IntImm coalesced_width; ///< Width for memory coalescing optimization
mutable ParallelOp par_op_; ///< Associated parallel operation
@@ -39,6 +40,7 @@ public:
.def_ro("dst", &AtomicAddNode::dst)
.def_ro("src_range", &AtomicAddNode::src_range)
.def_ro("dst_range", &AtomicAddNode::dst_range)
.def_ro("use_tma", &AtomicAddNode::use_tma)
.def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
}
@@ -46,6 +48,7 @@ public:
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(use_tma, other->use_tma) &&
equal(coalesced_width, other->coalesced_width);
}
@@ -54,6 +57,7 @@ public:
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(use_tma);
hash_reduce(coalesced_width);
}
@@ -67,6 +71,8 @@ protected:
Array<IterVar> MakeIterVars() const;
/// Generate buffer indices from iteration variables
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
/// Return buffer indices and size
std::pair<Array<PrimExpr>, PrimExpr> ReturnIndicesAndSize(int src_dst) const;
/// Create boundary predicate for memory safety
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
...
@@ -1571,6 +1571,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
for (auto coord : global_coords)
args.push_back(coord);
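// Ordinary bulk copies never reduce; stores now carry need_reduce = 0 just before
// eviction_policy so the argument layout matches the extended tma_store signature.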
int need_reduce = 0;
if (!is_load)
args.push_back(need_reduce);
args.push_back(this->eviction_policy);
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
Evaluate(Call(DataType::Handle(), op, args)));
@@ -1580,6 +1583,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
args.push_back(shared_addr);
for (auto coord : global_coords)
args.push_back(coord);
int need_reduce = 0;
if (!is_load)
args.push_back(need_reduce);
args.push_back(this->eviction_policy);
tma_copy = Evaluate(Call(DataType::Handle(), op, args));
}
@@ -1654,10 +1660,11 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
{shared_addr, global_addr, 0,
elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
} else {
int need_reduce = 0;
tma_copy = Evaluate(
Call(DataType::Handle(), tma_store(),
{global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
need_reduce, this->eviction_policy}));
}
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
return tma_copy;
...
@@ -1345,6 +1345,11 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::tma_store())) {
std::stringstream ss;
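// The second-to-last tma_store argument is now the need_reduce flag; when it is set,
// emit tl::tma_store_add instead of tl::tma_store.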
auto need_reduce = op->args[op->args.size() - 2].as<IntImmNode>()->value;
if (need_reduce) {
print_extern_call_stmt("tl::tma_store_add", 0, 2);
return;
}
auto eviction_policy =
this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
@@ -1353,7 +1358,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else {
ss << "tl::tma_store";
}
print_extern_call_stmt(ss.str(), 0, 2);
} else if (op->op.same_as(tl::ptx_ldmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
...
@@ -252,6 +252,16 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
: "memory");
}
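// TMA bulk reduce-add: adds the fp32 tile staged in shared memory into global memory as a
// single cp.reduce.async.bulk transaction (store_bytes is the transfer size in bytes).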
TL_DEVICE void tma_store_add(float *const smem_ptr, float *gmem_ptr,
int32_t const &store_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 "
"[%0], [%1], %2;\n"
:
: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
: "memory");
}
TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
...
@@ -116,7 +116,8 @@ def atomic_min(dst: Buffer,
def atomic_add(dst: Buffer,
value: PrimExpr,
memory_order: Optional[str] = None,
return_prev: bool = False,
use_tma: bool = False) -> PrimExpr:
""" """
Atomically add `value` into `dst`, returning a handle to the operation. Atomically add `value` into `dst`, returning a handle to the operation.
...@@ -225,7 +226,7 @@ def atomic_add(dst: Buffer, ...@@ -225,7 +226,7 @@ def atomic_add(dst: Buffer,
raise NotImplementedError( raise NotImplementedError(
"return_prev is not supported for tile-region-based atomic operations") "return_prev is not supported for tile-region-based atomic operations")
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst) return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma)
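# use_tma is forwarded as an extra intrinsic argument; the C++ AtomicAdd constructor above
# reads it as args[2], and the lowering switches to the TMA bulk reduce-add when it is nonzero.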
def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
...