"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "6e6ae26a3d5a60967e52603cf10ed2299f7077c0"
Commit b0c0db81 authored by Tri Dao's avatar Tri Dao
Browse files

Implement FlashAttention in Triton

parent c422fee3
...@@ -6,9 +6,11 @@ import torch.nn.functional as F ...@@ -6,9 +6,11 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_all, pytorch_profiler from flash_attn.utils.benchmark import benchmark_forward, benchmark_all, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.triton.fused_attention import attention as attention # from flash_attn.triton.fused_attention import attention as attention
from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
from flash_attn.flash_attn_triton_og import attention as attention_og
try: try:
from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
...@@ -45,19 +47,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True): ...@@ -45,19 +47,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
return output.to(dtype=qkv.dtype) return output.to(dtype=qkv.dtype)
def attention_triton(q, k, v):
"""
No dropout and only support causal=True.
Triton implementation seems to require q, k, v being contiguous?
Arguments:
q, k, v: (batch_size, nheads, seqlen, head_dim)
Output:
output: (batch_size, nheads, seqlen, head_dim)
"""
softmax_scale = 1.0 / math.sqrt(q.shape[-1])
return attention(q, k, v, softmax_scale)
def attention_megatron(qkv): def attention_megatron(qkv):
""" """
Arguments: Arguments:
...@@ -85,6 +74,10 @@ batch_size = 2 ...@@ -85,6 +74,10 @@ batch_size = 2
seqlen = 4096 seqlen = 4096
nheads = 12 nheads = 12
headdim = 128 headdim = 128
# batch_size = 64
# seqlen = 512
# nheads = 8
# headdim = 128
dropout_p = 0.0 dropout_p = 0.0
causal = True causal = True
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -100,9 +93,13 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b ...@@ -100,9 +93,13 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
repeats=repeats, desc='PyTorch Attention') repeats=repeats, desc='PyTorch Attention')
benchmark_all(flash_attn_qkvpacked_func, qkv, causal, repeats=repeats, desc='FlashAttention Triton')
pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal, backward=True)
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
requires_grad=True) for _ in range(3)] requires_grad=True) for _ in range(3)]
benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton') benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
# pytorch_profiler(attention, q, k, v, 1.0, backward=True)
if scaled_upper_triang_masked_softmax is not None: if scaled_upper_triang_masked_softmax is not None:
benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention') benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
This diff is collapsed.
# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py # [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
# for benchmarking. # for benchmarking.
# Fixing some dtype casting to make it work for bfloat16 # We fixed a few dtype cast to make it work for bf16
""" """
Fused Attention Fused Attention
...@@ -78,7 +78,7 @@ def _fwd_kernel( ...@@ -78,7 +78,7 @@ def _fwd_kernel(
acc = acc * acc_scale[:, None] acc = acc * acc_scale[:, None]
# update acc # update acc
v = tl.load(v_ptrs + start_n * stride_vk) v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(q.dtype) p = p.to(v.dtype)
acc += tl.dot(p, v) acc += tl.dot(p, v)
# update m_i and l_i # update m_i and l_i
l_i = l_i_new l_i = l_i_new
...@@ -178,7 +178,7 @@ def _bwd_kernel( ...@@ -178,7 +178,7 @@ def _bwd_kernel(
p = tl.exp(qk * sm_scale - m[:, None]) p = tl.exp(qk * sm_scale - m[:, None])
# compute dv # compute dv
do = tl.load(do_ptrs) do = tl.load(do_ptrs)
dv += tl.dot(p.to(q.dtype), do, trans_a=True) dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do) # compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr) Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
...@@ -189,7 +189,7 @@ def _bwd_kernel( ...@@ -189,7 +189,7 @@ def _bwd_kernel(
dk += tl.dot(ds.to(q.dtype), q, trans_a=True) dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
# # compute dq # # compute dq
dq = tl.load(dq_ptrs, eviction_policy="evict_last") dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds.to(q.dtype), k) dq += tl.dot(ds.to(k.dtype), k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last") tl.store(dq_ptrs, dq, eviction_policy="evict_last")
# # increment pointers # # increment pointers
dq_ptrs += BLOCK_M * stride_qm dq_ptrs += BLOCK_M * stride_qm
...@@ -270,95 +270,7 @@ class _attention(torch.autograd.Function): ...@@ -270,95 +270,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps, BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
num_stages=1, num_stages=1,
) )
return dq, dk, dv, None return dq.to(q.dtype), dk, dv, None
attention = _attention.apply attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['bwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
# only works on A100 at the moment
# bench_flash_attention.run(save_path='.', print_data=True)
...@@ -160,6 +160,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo ...@@ -160,6 +160,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v) # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None: if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0) attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling) output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling)
if query_padding_mask is not None: if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0) output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0)
...@@ -849,3 +851,56 @@ def test_flash_attn_multigpu(): ...@@ -849,3 +851,56 @@ def test_flash_attn_multigpu():
assert 0.99 <= dropout_fraction / dropout_p <= 1.01 assert 0.99 <= dropout_fraction / dropout_p <= 1.01
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# set seed
torch.random.manual_seed(0)
batch_size = 8
nheads = 4
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
output = flash_attn_func(q, k, v, causal)
output_ref, attn_ref = attention_ref(q, k, v, causal=causal)
output_pt, attn_pt = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True)
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
g = torch.randn_like(output)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment