"docs/source/en/training/overview.md" did not exist on "ba87c1607cae2ae00ab2547e911a101ed27ea18b"
Commit dc554693 authored by Tri Dao

Support arbitrary seqlen_k in Triton bwd

parent d11341fd
@@ -3,11 +3,14 @@ Based on the FlashAttention implementation from Phil Tillet.
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
 Changes:
-- Support both causal and non-causal attention.
+- Implement both causal and non-causal attention.
+- Implement cross-attention (not just self-attention).
 - Support arbitrary seqlens (not just multiples of 128) in the forward pass.
+- Support arbitrary seqlen_k (not just multiples of 128) in the backward pass. However, seqlen_q
+  must still be a multiple of 128.
 - Speed up the forward pass a bit (and only store the LSE instead of m and l).
 - Make the backward for d=128 much faster by reducing register spilling.
-- Add the option to parallelize the backward pass across seqlen_k, to deal with the case of
+- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
   small batch size * nheads.
 """
@@ -190,6 +193,7 @@ def _bwd_kernel_one_col_block(
     ATOMIC_ADD: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
+    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
@@ -209,8 +213,12 @@ def _bwd_kernel_one_col_block(
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     # k and v stay in SRAM throughout
-    k = tl.load(k_ptrs)
-    v = tl.load(v_ptrs)
+    if EVEN_N:
+        k = tl.load(k_ptrs)
+        v = tl.load(v_ptrs)
+    else:
+        k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
+        v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
     # loop over rows
     num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
     for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
@@ -220,6 +228,8 @@ def _bwd_kernel_one_col_block(
         q = tl.load(q_ptrs)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
+        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
+            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
         if IS_CAUSAL:
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
         lse_i = tl.load(LSE + offs_m_curr)
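Why the extra mask matters (illustration only, not code from this commit): the zero-filled K rows loaded above still produce finite scores, so without pushing the out-of-range columns to -inf they would receive softmax weight. A simplified stand-alone PyTorch analogue follows; the kernel itself normalizes with the stored LSE rather than a per-block softmax.

import torch

seqlen_k, BLOCK_N, start_n = 217, 128, 1                # second column block covers keys 128..255
offs_n = start_n * BLOCK_N + torch.arange(BLOCK_N)
qk = torch.randn(4, BLOCK_N)                            # one block of scores for 4 query rows

wrong = torch.softmax(qk, dim=-1)                       # columns >= seqlen_k wrongly get weight
qk = qk.masked_fill(offs_n[None, :] >= seqlen_k, float('-inf'))
right = torch.softmax(qk, dim=-1)
assert right[:, offs_n >= seqlen_k].abs().max() == 0    # padded keys get exactly zero probability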
@@ -252,8 +262,12 @@ def _bwd_kernel_one_col_block(
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :])
-    tl.store(dv_ptrs, dv)
-    tl.store(dk_ptrs, dk)
+    if EVEN_N:
+        tl.store(dv_ptrs, dv)
+        tl.store(dk_ptrs, dk)
+    else:
+        tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
+        tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)


 def init_to_zero(name):
@@ -282,6 +296,12 @@ def init_to_zero(name):
     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
     # reset_to_zero=['DQ']
 )
+@triton.heuristics(
+    {
+        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
+        "EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0,
+    }
+)
 @triton.jit
 def _bwd_kernel(
     Q, K, V,
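For intuition (illustrative values, not from the commit): the heuristics above just check block divisibility when the kernel is launched, so a ragged seqlen_k selects the masked EVEN_N == False specialization while a 128-aligned seqlen_q keeps the fast unmasked path.

# Plain-Python evaluation of the two heuristics, using example sizes from the test below;
# the block sizes are assumed, the real ones come from the autotune configs.
args = {"seqlen_q": 128, "seqlen_k": 217, "BLOCK_M": 128, "BLOCK_N": 128}
EVEN_M = args["seqlen_q"] % args["BLOCK_M"] == 0   # True  -> no masking needed along seqlen_q
EVEN_N = args["seqlen_k"] % args["BLOCK_N"] == 0   # False -> masked loads/stores along seqlen_k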
@@ -300,6 +320,7 @@ def _bwd_kernel(
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
+    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
 ):
     off_hb = tl.program_id(1)
@@ -329,6 +350,7 @@ def _bwd_kernel(
                 ATOMIC_ADD=False,
                 IS_CAUSAL=IS_CAUSAL,
                 BLOCK_HEADDIM=BLOCK_HEADDIM,
+                EVEN_M=EVEN_M, EVEN_N=EVEN_N,
                 BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
             )
     else:
@@ -343,6 +365,7 @@ def _bwd_kernel(
             ATOMIC_ADD=True,
             IS_CAUSAL=IS_CAUSAL,
             BLOCK_HEADDIM=BLOCK_HEADDIM,
+            EVEN_M=EVEN_M, EVEN_N=EVEN_N,
             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
         )
@@ -394,8 +417,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
     do = do.contiguous()
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
-    assert seqlen_q % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128'
-    assert seqlen_k % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128'
+    assert seqlen_q % 128 == 0, 'Backward pass currently only supports seqlens that are multiples of 128'
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
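As a quick summary of the relaxed precondition (a hypothetical helper, not part of the commit): only seqlen_q still has to be a multiple of 128, mirroring the single remaining assert above and the updated run_bwd condition in the test below.

def triton_bwd_supported(seqlen_q: int, seqlen_k: int) -> bool:
    # seqlen_k may now be arbitrary; seqlen_q must still be a multiple of 128
    return seqlen_q % 128 == 0

assert triton_bwd_supported(128, 217)       # ragged seqlen_k is now fine
assert not triton_bwd_supported(113, 203)   # seqlen_q not a multiple of 128 -> still unsupported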
...
...@@ -860,12 +860,12 @@ from flash_attn.flash_attn_triton import flash_attn_func ...@@ -860,12 +860,12 @@ from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [64, 128]) @pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (256, 512), (512, 256), (1024, 1024), (2048, 2048)]) @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(127, 256)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 211)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype): def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
@@ -887,7 +887,7 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
     print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
-    run_bwd = (seqlen_q % 128 == 0) and (seqlen_k % 128 == 0)
+    run_bwd = seqlen_q % 128 == 0
     if run_bwd:
         g = torch.randn_like(output)
         dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
...