Commit d11341fd authored by Tri Dao

Fix Triton fwd to support seqlen not multiples of 128

parent b0c0db81
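The gist of the change: the forward kernel previously assumed seqlen_q was a multiple of the block size, so the per-row softmax statistics (Lse, TMP) were laid out with a per-(batch, head) stride of seqlen_q. This commit introduces seqlen_q_rounded for that stride and masks the boundary loads. A minimal sketch of the round-up arithmetic this relies on (the helper name and the multiple of 128 are illustrative assumptions based on the commit message, not code from the diff):

```python
import math

def round_up(seqlen: int, multiple: int = 128) -> int:
    # Hypothetical helper, not from the diff: round seqlen up to the next multiple.
    return math.ceil(seqlen / multiple) * multiple

# A 113-token query length gets a 128-row layout, so every (batch, head) slice
# of the LSE/TMP buffers has the same stride regardless of the true seqlen.
assert round_up(113) == 128
assert round_up(256) == 256
```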
@@ -4,6 +4,7 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
 Changes:
 - Support both causal and non-causal attention.
+- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
 - Speed up the forward pass a bit (and only store the LSE instead of m and l).
 - Make the backward for d=128 much faster by reducing register spilling.
 - Add the option to parallelize the backward pass across seqlen_k, to deal with the case of
...@@ -30,7 +31,7 @@ import triton.language as tl ...@@ -30,7 +31,7 @@ import triton.language as tl
@triton.heuristics( @triton.heuristics(
{ {
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
} }
) )
@triton.jit @triton.jit
@@ -42,7 +43,7 @@ def _fwd_kernel(
     stride_kb, stride_kh, stride_kn,
     stride_vb, stride_vh, stride_vn,
     stride_ob, stride_oh, stride_om,
-    nheads, seqlen_q, seqlen_k,
+    nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
     CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
@@ -68,12 +69,14 @@ def _fwd_kernel(
     k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
     v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
     # initialize pointer to m and l
-    t_ptrs = TMP + off_hb * seqlen_q + offs_m
+    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
     lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
-    if EVEN_M:
+    # [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
+    # tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
+    if EVEN_M & EVEN_N:
         q = tl.load(q_ptrs)
     else:
         q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
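The masked load above is what makes ragged tail blocks work: rows beyond seqlen_q are filled with 0.0 and contribute nothing to the accumulation. A rough NumPy sketch of the same indexing idea (not the Triton code itself, just an illustration of the mask):

```python
import numpy as np

BLOCK_M, seqlen_q, headdim = 128, 113, 64
q_full = np.random.randn(seqlen_q, headdim).astype(np.float32)

# offs_m spans a full block of rows even though only the first 113 exist.
offs_m = np.arange(BLOCK_M)
mask = offs_m < seqlen_q                     # rows 113..127 are out of range

# Rough equivalent of tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0):
q_block = np.zeros((BLOCK_M, headdim), dtype=np.float32)
q_block[mask] = q_full                       # in-range rows copied, padding rows stay 0.0

assert q_block.shape == (BLOCK_M, headdim)
assert np.allclose(q_block[seqlen_q:], 0.0)
```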
@@ -130,7 +133,7 @@ def _fwd_kernel(
     start_m = tl.program_id(0)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     # write back l and m
-    lse_ptrs = Lse + off_hb * seqlen_q + offs_m
+    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
     tl.store(lse_ptrs, lse_i)
     # initialize pointers to output
     offs_n = tl.arange(0, BLOCK_HEADDIM)
@@ -373,7 +376,7 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
         k.stride(0), k.stride(2), k.stride(1),
         v.stride(0), v.stride(2), v.stride(1),
         o.stride(0), o.stride(2), o.stride(1),
-        nheads, seqlen_q, seqlen_k,
+        nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
         seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
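For the kernel to index TMP and Lse with off_hb * seqlen_q_rounded, the host-side wrapper has to allocate those buffers with the rounded length and pass it through. A hedged sketch of what that allocation could look like (names, shapes, and the tensor layout here are inferred from the arguments visible in this hunk, not copied from the file):

```python
import math
import torch

def alloc_softmax_buffers(q: torch.Tensor, block: int = 128):
    # Sketch only: inferred from the kernel arguments above, not the actual wrapper.
    batch, seqlen_q, nheads, _ = q.shape          # assumed (batch, seqlen, nheads, headdim) layout
    seqlen_q_rounded = math.ceil(seqlen_q / block) * block
    # One float32 statistic per (batch, head, padded query position), so the kernel can
    # index with off_hb * seqlen_q_rounded + offs_m without running past the buffer.
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    return lse, tmp, seqlen_q_rounded
```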
@@ -855,15 +855,17 @@ def test_flash_attn_multigpu():
 from flash_attn.flash_attn_triton import flash_attn_func
 
 
+@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+# @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [64, 128])
 # @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)])
+@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(127, 256)])
 def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
@@ -885,22 +887,25 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
     print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
-    g = torch.randn_like(output)
-    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-    dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
-    dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
-    print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
-    print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
-    print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
-    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
-    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
-    print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
+    run_bwd = (seqlen_q % 128 == 0) and (seqlen_k % 128 == 0)
+    if run_bwd:
+        g = torch.randn_like(output)
+        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+        dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
+        dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
+        print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
+        print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
+        print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
+        print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
+        print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
+        print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
     # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
-    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
-    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
-    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+    if run_bwd:
+        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
+        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
+        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
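Since the backward checks are gated behind run_bwd, which is only true when both seqlen_q and seqlen_k are multiples of 128, only the forward pass is exercised for the new ragged shapes. A forward-only sanity check along the same lines could look like this (a sketch; the keyword arguments of flash_attn_func are not visible in this diff, so only the positional q, k, v call is used):

```python
import torch
from flash_attn.flash_attn_triton import flash_attn_func

# Forward-only smoke test with seqlens that are not multiples of 128 (the case this
# commit enables). Shapes follow the (batch, seqlen, nheads, headdim) layout implied
# by the strides passed to the kernel in the hunks above.
batch, nheads, d = 2, 4, 64
q = torch.randn(batch, 113, nheads, d, device='cuda', dtype=torch.float16)
k = torch.randn(batch, 203, nheads, d, device='cuda', dtype=torch.float16)
v = torch.randn(batch, 203, nheads, d, device='cuda', dtype=torch.float16)
out = flash_attn_func(q, k, v)
print(out.shape)  # expected: (2, 113, 4, 64)
```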