Commit 86862cfd authored by Tri Dao

Implement attention bias for Triton version

parent 470010f5
@@ -122,7 +122,7 @@ def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None
 def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
-                  dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
+                  dropout_mask=None, causal=False, bias=None, upcast=True, reorder_ops=False):
     """
     Arguments:
         q: (batch_size, seqlen_q, nheads, head_dim)
@@ -132,6 +132,7 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
         key_padding_mask: (batch_size, seqlen_k)
         dropout_p: float
         dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        bias: (batch_size, nheads, seqlen_q, seqlen_k)
         upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
             output back to fp16/bf16.
         reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
@@ -150,6 +151,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
         scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
     else:
         scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
+    if bias is not None:
+        scores = (scores + bias).to(dtype=scores.dtype)
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
     if causal:
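
The two added lines fold an additive bias into the pre-softmax scores of the reference implementation. A minimal, self-contained sketch of the same idea (biased_attention_ref is a hypothetical helper written for illustration, not part of the repo):

    import math
    import torch

    def biased_attention_ref(q, k, v, bias=None):
        # q: (batch, seqlen_q, nheads, head_dim); k, v: (batch, seqlen_k, nheads, head_dim)
        d = q.shape[-1]
        scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
        if bias is not None:
            # bias broadcasts against (batch, nheads, seqlen_q, seqlen_k)
            scores = (scores + bias).to(dtype=scores.dtype)
        attn = torch.softmax(scores, dim=-1)
        return torch.einsum('bhts,bshd->bthd', attn, v)

Padding masks, causal masking, and dropout are applied after this point in the real attention_ref, as in the unchanged context lines above.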
@@ -863,11 +866,13 @@ from flash_attn.flash_attn_triton import flash_attn_func
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-# @pytest.mark.parametrize('d', [48])
+# @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)])
-def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
+# @pytest.mark.parametrize('bias_shape', (['1h1k']))
+def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
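
The new bias_shape ids encode which dimensions of the bias are broadcast: '1' marks a size-1 dimension, while 'b', 'h', 'q', and 'k' mark full batch, head, query-length, and key-length dimensions. A quick sketch checking that every variant broadcasts to the full (batch_size, nheads, seqlen_q, seqlen_k) score shape (the sizes below are illustrative, not taken from the test):

    import torch

    batch_size, nheads, seqlen_q, seqlen_k = 8, 4, 113, 203  # illustrative sizes
    score_shape = (batch_size, nheads, seqlen_q, seqlen_k)
    bias_shapes = {
        '1h1k': (1, nheads, 1, seqlen_k),
        '1hqk': (1, nheads, seqlen_q, seqlen_k),
        'b11k': (batch_size, 1, 1, seqlen_k),
        'b1qk': (batch_size, 1, seqlen_q, seqlen_k),
    }
    for name, shape in bias_shapes.items():
        # Every bias variant must broadcast against the attention-score shape.
        assert torch.broadcast_shapes(shape, score_shape) == score_shape, name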
@@ -877,12 +882,23 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     nheads = 4
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
     k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+    if bias_shape == '1h1k':
+        bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == '1hqk':
+        bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b11k':
+        bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b1qk':
+        bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    else:
+        bias = None
     q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
-    output = flash_attn_func(q, k, v, causal)
-    output_ref, attn_ref = attention_ref(q, k, v, causal=causal)
-    output_pt, attn_pt = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True)
+    output = flash_attn_func(q, k, v, bias, causal)
+    output_ref, attn_ref = attention_ref(q, k, v, bias=bias, causal=causal)
+    output_pt, attn_pt = attention_ref(q, k, v, bias=bias, causal=causal, upcast=False,
+                                       reorder_ops=True)
     print(f'Output max diff: {(output - output_ref).abs().max().item()}')
     print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
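
Putting the pieces together, a hedged usage sketch of the Triton interface as it is called in this diff, flash_attn_func(q, k, v, bias, causal); tensor sizes are illustrative, and the call assumes a CUDA device with the Triton kernel available:

    import torch
    from flash_attn.flash_attn_triton import flash_attn_func

    batch_size, seqlen, nheads, d = 2, 128, 4, 64  # illustrative sizes
    q, k, v = [torch.randn(batch_size, seqlen, nheads, d, device='cuda', dtype=torch.float16,
                           requires_grad=True) for _ in range(3)]
    # Additive bias in fp32, broadcastable over (batch, nheads, seqlen_q, seqlen_k),
    # mirroring the '1h1k' case constructed in the test above.
    bias = torch.randn(1, nheads, 1, seqlen, dtype=torch.float, device='cuda')
    out = flash_attn_func(q, k, v, bias, False)  # positional arguments as used in this diff
    out.backward(torch.randn_like(out))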
@@ -919,13 +935,14 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
-# @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-@pytest.mark.parametrize('d', [64, 128])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+# @pytest.mark.parametrize('d', [96])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
-def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512)])
+@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
+def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
@@ -935,19 +952,31 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     nheads = 4
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
     k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+    if bias_shape == '1h1k':
+        bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == '1hqk':
+        bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b11k':
+        bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b1qk':
+        bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    else:
+        bias = None
     q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
-    output_0 = flash_attn_func(q, k, v, causal)
+    output_0 = flash_attn_func(q, k, v, bias, causal)
     g = torch.randn_like(output_0)
     dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
     # The SEQUENCE_PARALLEL option for the bwd makes dq non-deterministic
     deterministic_dq = False
-    equal_fn = (torch.equal if deterministic_dq
-                else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5))
+    # Numerical error if we just do any arithmetic on dq
+    dq_atol = ((dq_0 + 0.3 - 0.3) - dq_0).abs().max().item()
+    equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
+    # Run 10000 times and check that the results don't change
     for i in range(10000):
-        output = flash_attn_func(q, k, v, causal)
+        output = flash_attn_func(q, k, v, None, causal)
         output_equal = torch.equal(output, output_0)
         if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
             print(f'Output max diff: {(output - output_0).abs().max().item()}')
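
The dq_atol line above replaces the previous hard-coded tolerances with one derived from the tensor itself: adding and then subtracting the same constant is a mathematical no-op, so whatever residual it leaves is pure floating-point rounding at dq's magnitude and dtype, and that residual becomes the allclose tolerance. A standalone sketch of the trick (values are illustrative):

    import torch

    dq = torch.randn(1024).to(torch.float16)  # stand-in for a gradient tensor
    # A mathematical no-op; any residual is rounding error at dq's scale and dtype.
    dq_atol = ((dq + 0.3 - 0.3) - dq).abs().max().item()
    print(f'noise floor used as allclose atol: {dq_atol:.2e}')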