Commit 008951f1 authored by Tri Dao

Support all head dimensions up to 128 in the Triton fwd

parent b910bf14
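
In short: the forward pass now rounds the head dimension up to a power of two (at least 16) for its block size and masks off the padded columns on every load and store, so any head dimension up to 128 works, not just 16/32/64/128. A minimal sketch of that padding rule, with illustrative names (head_dim, block_headdim, tail_mask) that are not part of the commit:

import torch
import triton

def headdim_padding_sketch(head_dim: int):
    # Round the head dimension up to a power of two, floored at 16, mirroring
    # BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) in the forward wrapper below.
    block_headdim = max(triton.next_power_of_2(head_dim), 16)
    # Columns past the real head dimension are masked out, mirroring the
    # `offs_d[None, :] < headdim` masks added to the kernel's loads and stores.
    offs_d = torch.arange(block_headdim)
    tail_mask = offs_d < head_dim
    return block_headdim, tail_mask

# e.g. head_dim=40 -> block_headdim=64, tail_mask True for the first 40 columns only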
@@ -6,6 +6,7 @@ Changes:
- Implement both causal and non-causal attention.
- Implement cross-attention (not just self-attention).
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), in the forward pass.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
@@ -31,6 +32,7 @@ import triton.language as tl
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
@@ -42,11 +44,11 @@ def _fwd_kernel(
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
@@ -76,19 +78,34 @@ def _fwd_kernel(
    # [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        if not EVEN_N:
@@ -111,10 +128,18 @@ def _fwd_kernel(
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
@@ -138,9 +163,16 @@ def _fwd_kernel(
    offs_n = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs, acc_o,
                     mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
@triton.heuristics(
@@ -209,8 +241,8 @@ def _bwd_kernel_one_col_block(
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
@@ -390,7 +422,8 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
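
For concreteness, a few worked values of this padding rule and of the EVEN_HEADDIM heuristic it feeds (headdim == BLOCK_HEADDIM selects the unmasked fast path). A small check script from the editor, not part of the commit:

import triton

for d, expected_block in [(16, 16), (40, 64), (64, 64), (88, 128), (128, 128)]:
    block_headdim = max(triton.next_power_of_2(d), 16)
    even_headdim = (d == block_headdim)  # True -> unmasked loads/stores, False -> masked tail
    assert block_headdim == expected_block
# d=40 and d=88 (both exercised by the new test below) take the masked path;
# power-of-two head dims like 16/32/64/128 keep the original unmasked path.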
@@ -413,11 +446,11 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        causal, BLOCK_HEADDIM,
        # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        # num_warps=num_warps,
        # num_stages=1,
@@ -431,6 +464,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
    do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert d in {16, 32, 64, 128}
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    # dq_accum = torch.zeros_like(q, dtype=torch.float32)
...
@@ -861,7 +861,7 @@ from flash_attn.flash_attn_triton import flash_attn_func
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [40, 64, 128, 88])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
@@ -887,6 +887,8 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    run_bwd = d in [16, 32, 64, 128]
    if run_bwd:
        g = torch.randn_like(output)
        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
        dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
@@ -903,6 +905,7 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    if run_bwd:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
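
A usage note in closing (an editor's sketch, not part of the commit): after this change the Triton forward accepts any head dimension up to 128, while the backward still asserts d in {16, 32, 64, 128}, which is why the test above only checks gradients when run_bwd is true. Assuming flash_attn_func exposes the same (q, k, v, causal=...) interface as _flash_attn_forward, a forward-only call with an odd head dimension would look roughly like:

import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, seqlen, nheads, d = 2, 1024, 8, 40   # d=40: accepted by the fwd, not by the bwd yet
q, k, v = [torch.randn(batch, seqlen, nheads, d, device='cuda', dtype=torch.float16)
           for _ in range(3)]
with torch.no_grad():                        # stick to the forward pass for this head dim
    out = flash_attn_func(q, k, v, causal=True)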