Commit 008951f1 authored by Tri Dao's avatar Tri Dao
Browse files

Support all head dimensions up to 128 in the Triton fwd

parent b910bf14
......@@ -6,6 +6,7 @@ Changes:
- Implement both causal and non-causal attention.
- Implement cross-attention (not just self-attention).
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), in the forward pass.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
......@@ -31,6 +32,7 @@ import triton.language as tl
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
......@@ -42,11 +44,11 @@ def _fwd_kernel(
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_ob, stride_oh, stride_om,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
......@@ -76,19 +78,34 @@ def _fwd_kernel(
# [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0)
# loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
if not EVEN_N:
......@@ -111,10 +128,18 @@ def _fwd_kernel(
acc_o = acc_o * acc_o_scale[:, None]
# update acc_o
if EVEN_N:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
v = tl.load(v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0)
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
......@@ -138,9 +163,16 @@ def _fwd_kernel(
offs_n = tl.arange(0, BLOCK_HEADDIM)
out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(out_ptrs, acc_o,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
@triton.heuristics(
......@@ -209,8 +241,8 @@ def _bwd_kernel_one_col_block(
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
# k and v stay in SRAM throughout
# [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_N=False,
# if we just call # tl.load(k_ptrs), we get the wrong output!
# [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
# if we just call tl.load(k_ptrs), we get the wrong output!
if EVEN_N & EVEN_M:
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
......@@ -390,7 +422,8 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
_, seqlen_k, _, _ = k.shape
assert k.shape == (batch, seqlen_k, nheads, d)
assert v.shape == (batch, seqlen_k, nheads, d)
assert d in {16, 32, 64, 128}
assert d <= 128, 'FlashAttention only support head dimensions up to 128'
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
assert q.is_cuda and k.is_cuda and v.is_cuda
......@@ -413,11 +446,11 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
k.stride(0), k.stride(2), k.stride(1),
v.stride(0), v.stride(2), v.stride(1),
o.stride(0), o.stride(2), o.stride(1),
nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
causal, d,
causal, BLOCK_HEADDIM,
# BLOCK_M=BLOCK, BLOCK_N=BLOCK,
# num_warps=num_warps,
# num_stages=1,
......@@ -431,6 +464,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
do = do.contiguous()
batch, seqlen_q, nheads, d = q.shape
_, seqlen_k, _, _ = k.shape
assert d in {16, 32, 64, 128}
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
assert lse.shape == (batch, nheads, seqlen_q_rounded)
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
......
......@@ -861,7 +861,7 @@ from flash_attn.flash_attn_triton import flash_attn_func
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [64, 128])
@pytest.mark.parametrize('d', [40, 64, 128, 88])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
......@@ -887,6 +887,8 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
run_bwd = d in [16, 32, 64, 128]
if run_bwd:
g = torch.randn_like(output)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
......@@ -903,6 +905,7 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
if run_bwd:
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment