Commit 9b0bc978 authored by Tri Dao

Fix race condition in Triton fwd

parent 215930bc
""" """
We use the FlashAttention implementation from Phil Tillet as a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
Changes:
@@ -13,6 +13,13 @@ more testing since there seems to be some race conditions due to the Triton comp
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward.
- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
slower in other cases.
- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
""" """
import math
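Given the interface differences listed in the docstring above (no dropout, no variable-length batches yet), here is a minimal usage sketch of the Triton flash_attn_func, mirroring the call pattern of the race-condition test added at the bottom of this commit. The import path is an assumption and may differ depending on where flash_attn_triton.py lives in the package:

import torch
from flash_attn.flash_attn_triton import flash_attn_func  # assumed import path

batch_size, seqlen, nheads, d = 2, 1024, 4, 64
q, k, v = [torch.randn(batch_size, seqlen, nheads, d, device='cuda', dtype=torch.float16,
                       requires_grad=True)
           for _ in range(3)]
out = flash_attn_func(q, k, v, True)  # fourth positional arg is causal, as in the test below
out.sum().backward()                  # exercises the Triton backward; grads land in q.grad / k.grad / v.grad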
@@ -26,7 +33,8 @@ import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        # This config has a race condition when EVEN_M == False, disabling it for now.
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
)
@@ -34,6 +42,7 @@ import triton.language as tl
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        # "EVEN_N": lambda args: False,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
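For context on the two decorators touched by this patch: @triton.autotune benchmarks every triton.Config in configs (re-tuning whenever the values named in key change) and caches the fastest one, so commenting a config out simply removes it from the search space; @triton.heuristics computes compile-time flags such as EVEN_M/EVEN_N from the launch arguments, letting the kernel take an unmasked fast path when a block is fully in bounds. A minimal, self-contained sketch of the same pattern on a toy kernel (the kernel, names, and block sizes here are illustrative only, not part of the patch):

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=1),
    ],
    key=["n_elements"],  # re-run the benchmark whenever this value changes
)
@triton.heuristics({"EVEN": lambda args: args["n_elements"] % args["BLOCK"] == 0})
@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, scale,
                  BLOCK: tl.constexpr, EVEN: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    if EVEN:
        # Block is guaranteed fully in bounds: unmasked load/store (fast path).
        x = tl.load(x_ptr + offsets)
        tl.store(out_ptr + offsets, x * scale)
    else:
        # Partial block: mask off out-of-bounds lanes (safe path).
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        tl.store(out_ptr + offsets, x * scale, mask=mask)

x = torch.randn(10000, device='cuda')
out = torch.empty_like(x)
grid = lambda META: (triton.cdiv(x.numel(), META["BLOCK"]),)
_scale_kernel[grid](x, out, x.numel(), 2.0)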
@@ -95,7 +104,7 @@ def _fwd_kernel(
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
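The hunk above ends before the masked fallback. Purely for readability, a hypothetical sketch of the kind of masked branch this guard pairs with; the offs_n, offs_d, seqlen_k, and headdim names are assumed from the kernel's conventions, and this is a paraphrase rather than the hunk's actual continuation. The same EVEN_N & EVEN_M guard change is applied to the V load in the next hunk.

        # Hypothetical sketch (assumed names, not the patch itself): when the K tile is
        # partially out of range, out-of-bounds rows/columns are masked to 0 instead of
        # being read unmasked, which is what the EVEN_N & EVEN_M fast path relies on.
        if EVEN_N & EVEN_M:
            k = tl.load(k_ptrs + start_n * stride_kn)  # fully in-bounds tile
        else:
            k = tl.load(k_ptrs + start_n * stride_kn,
                        mask=((start_n + offs_n)[:, None] < seqlen_k)
                             & (offs_d[None, :] < headdim),
                        other=0.0)                     # partial tile, masked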
@@ -129,7 +138,7 @@ def _fwd_kernel(
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
@@ -299,7 +308,8 @@ def _bwd_kernel_one_col_block(
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        if not EVEN_M:
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
...
@@ -912,3 +912,44 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
    q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
    output_0 = flash_attn_func(q, k, v, causal)
    g = torch.randn_like(output_0)
    dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
    # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
    for i in range(10000):
        output = flash_attn_func(q, k, v, causal)
        # print(f'Output max diff: {(output - output_0).abs().max().item()}')
        # dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
        # print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
        # print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
        # print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
        assert torch.equal(output, output_0)
        # assert torch.equal(dq, dq_0)
        # assert torch.equal(dk, dk_0)
        # assert torch.equal(dv, dv_0)