Commit 0e8c46ae authored by Tri Dao

Run isort and black on test files

parent 7fcd3e6a
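This commit is a pure formatting pass. For reference, a pass like this is typically reproduced by running the two tools named in the commit message over the test directory; the exact configuration is not recorded in the commit itself, so the short Python sketch below is illustrative only (paths and flags are assumptions):

# Illustrative sketch only: re-run the isort + black formatting pass (assumed paths/flags).
import subprocess

subprocess.run(["isort", "tests"], check=True)
subprocess.run(["black", "tests"], check=True)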
import math
from functools import partial
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.ops.fused_dense import FusedDense, FusedMLP
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_residual", [False, True])
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
device = "cuda"
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(
batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
)
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
model = FusedDense(
in_features,
out_features,
bias=has_bias,
return_residual=return_residual,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
if has_bias:
@@ -37,10 +42,16 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
out = model(x)
else:
out, x_copy = model(x)
x_copy = (
x_copy[..., :out_features]
if out_features < in_features
else F.pad(x_copy, (0, out_features - in_features))
)
x_pt_copy = (
x_pt[..., :out_features]
if out_features < in_features
else F.pad(x_pt, (0, out_features - in_features))
)
# Just add some random function of the residual
out_pt = out_pt + F.gelu(x_pt_copy)
out = out + F.gelu(x_copy)
@@ -60,43 +71,64 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("heuristic", ["auto", -1])
# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize("checkpoint_lvl", [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize("return_residual", [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
@pytest.mark.parametrize("has_bias1", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
@pytest.mark.parametrize("activation", ["gelu_approx", "relu"])
# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
# @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_mlp(
in_features,
out_features,
activation,
has_bias1,
has_bias2,
return_residual,
checkpoint_lvl,
heuristic,
dtype,
):
device = "cuda"
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(
batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
)
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(
in_features, out_features, bias=has_bias1, device=device, dtype=dtype
)
model_pt_fc2 = torch.nn.Linear(
out_features, in_features, bias=has_bias2, device=device, dtype=dtype
)
model = FusedMLP(
in_features,
out_features,
in_features,
activation=activation,
bias1=has_bias1,
bias2=has_bias2,
return_residual=return_residual,
checkpoint_lvl=checkpoint_lvl,
heuristic=heuristic,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.fc1.weight.copy_(model_pt_fc1.weight)
if has_bias1:
@@ -104,8 +136,11 @@ def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2,
model.fc2.weight.copy_(model_pt_fc2.weight)
if has_bias2:
model.fc2.bias.copy_(model_pt_fc2.bias)
activation_fn = (
partial(F.gelu, approximate="tanh")
if activation == "gelu_approx"
else partial(F.relu, inplace=True)
)
out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
if not return_residual:
out = model(x)
@@ -121,13 +156,17 @@ def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2,
out_pt.backward(g)
out.backward(g)
# The error for relu is higher still
if activation == "relu":
atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
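# (presumably because these gradients are reduced over batch_size * seqlen elements, so
# accumulated rounding error grows with the reduction size; hence the atol * 10 and atol * 5
# scaling in the asserts below)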
assert torch.allclose(
model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10
)
if has_bias1:
assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10
)
if has_bias2:
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
@@ -3,36 +3,33 @@
import math
import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, FusedMLP, ParallelFusedMLP
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias", [True, False])
# @pytest.mark.parametrize('has_bias', [False])
@pytest.mark.parametrize("out_features", [1024])
@pytest.mark.parametrize("in_features", [4096])
def test_fused_linear_bias(
in_features, out_features, has_bias, sequence_parallel, world_size, dtype
):
assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -41,77 +38,95 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle
batch_size = 2
seqlen = 512
assert batch_size * seqlen % world_size == 0
x_pt = torch.randn(
batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
)
if sequence_parallel:
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
partition_out_features = out_features // world_size
model = ColumnParallelLinear(
in_features,
out_features,
parallel_state.get_tensor_model_parallel_group(),
bias=has_bias,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.weight.copy_(
model_pt.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
)
if has_bias:
model.bias.copy_(
model_pt.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
)
out = model(x)
out_pt = model_pt(x_pt)
assert torch.allclose(
out,
out_pt[:, rank * partition_out_features : (rank + 1) * partition_out_features],
rtol=rtol,
atol=atol,
)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out_pt) / 32
out_pt.backward(g)
out.backward(g[:, rank * partition_out_features : (rank + 1) * partition_out_features])
parallel_state.destroy_model_parallel()
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol,
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.weight.grad,
model_pt.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
rtol=rtol,
atol=atol * 10,
)
if has_bias:
assert torch.allclose(
model.bias.grad,
model_pt.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
rtol=rtol,
atol=atol * 5,
)
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize("out_features", [4096])
@pytest.mark.parametrize("in_features", [1024])
def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -120,77 +135,103 @@ def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, worl
batch_size = 2
seqlen = 512
assert batch_size * seqlen % world_size == 0
x_pt = torch.randn(
batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model_pt_fc2 = torch.nn.Linear(
out_features, in_features, bias=has_bias2, device=device, dtype=dtype
)
partition_out_features = out_features // world_size
partition_in_features = in_features // world_size
model = ParallelFusedMLP(
in_features,
out_features,
in_features,
process_group=parallel_state.get_tensor_model_parallel_group(),
bias2=has_bias2 and rank == 0,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
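# Note on the weight copies below (the standard column/row-parallel MLP partitioning, spelled out
# here for readability rather than taken from the library docs): fc1 is split along its output
# dimension, so each rank holds rows [rank * partition_out_features, (rank + 1) *
# partition_out_features) of the reference fc1.weight and the matching bias slice; fc2 is split
# along its input dimension, so the same rank holds the matching columns of the reference
# fc2.weight, and fc2.bias lives only on rank 0.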
with torch.no_grad():
model.fc1.weight.copy_(
model_pt_fc1.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
)
model.fc1.bias.copy_(
model_pt_fc1.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
)
model.fc2.weight.copy_(
model_pt_fc2.weight[
:, rank * partition_out_features : (rank + 1) * partition_out_features
]
)
if has_bias2 and rank == 0:
model.fc2.bias.copy_(model_pt_fc2.bias)
out = model(x)
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate="tanh"))
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
out_pt.backward(g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol,
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.fc1.weight.grad,
model_pt_fc1.weight.grad[
rank * partition_out_features : (rank + 1) * partition_out_features
],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.fc1.bias.grad,
model_pt_fc1.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.fc2.weight.grad,
model_pt_fc2.weight.grad[
:, rank * partition_out_features : (rank + 1) * partition_out_features
],
rtol=rtol,
atol=atol * 10,
)
if has_bias2 and rank == 0:
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
MAX_HEADDIM_SM8x = 192
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
)
return padding_mask
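# Illustrative usage of the helper above: generate_random_padding_mask(max_seqlen=8, batch_size=2,
# device="cuda") returns a bool tensor of shape (2, 8) whose row b has lengths[b] leading True
# positions followed by False padding positions.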
def generate_qkv(
q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
@@ -53,22 +57,28 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
)
max_seqlen_k = seqlen_k
if qkvpacked:
@@ -79,9 +89,17 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
@@ -89,27 +107,57 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
else:
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
upcast=True,
reorder_ops=False,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
@@ -136,14 +184,16 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if causal:
causal_mask = torch.triu(
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
)
scores.masked_fill_(causal_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
@@ -152,25 +202,59 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
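# In short, attention_ref computes softmax(q @ k^T / sqrt(d)) with -inf written into padded key
# positions and, when causal=True, into the strictly upper-triangular positions, then applies the
# dropout mask with 1 / (1 - dropout_p) rescaling before contracting with v. For example, with
# seqlen_q = seqlen_k = 3, torch.triu(torch.ones(3, 3), 1) masks the score entries (0, 1), (0, 2)
# and (1, 2), so query i only attends to keys 0..i.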
def attention_kvpacked_ref(
q,
kv,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
upcast=True,
reorder_ops=False,
):
return attention_ref(
q,
kv[:, :, 0],
kv[:, :, 1],
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
reorder_ops=reorder_ops,
)
def attention_qkvpacked_ref(
qkv,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
upcast=True,
reorder_ops=False,
):
return attention_ref(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
key_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
reorder_ops=reorder_ops,
)
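# Packed layouts assumed by the two wrappers above (matching how the tests and generate_qkv build
# them): qkv has shape (batch_size, seqlen, 3, nheads, d) with q, k, v at indices 0, 1, 2 of dim 2,
# and kv has shape (batch_size, seqlen_k, 2, nheads_k, d); each wrapper simply unpacks its input
# and forwards to attention_ref.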
def generate_sparsity_mask(seqlen, sparsity=0.3):
@@ -182,7 +266,7 @@ def generate_sparsity_mask(seqlen, sparsity=0.3):
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
nrow, ncol = seqlen // 16, seqlen // 256
mask = torch.rand(nrow, ncol, device="cuda") < sparsity
return mask
@@ -201,22 +285,23 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask
q, k, v = qkv.float().unbind(dim=2)
d = qkv.shape[-1]
seqlen = qkv.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
blockmask = blockmask[:seqlen, :seqlen]
scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
attention = torch.softmax(scores, dim=-1)
attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
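# Note on the block mask above: blockmask is defined per (16 x 256) tile of the (seqlen, seqlen)
# score matrix, i.e. blockmask[i, j] keeps or drops scores[..., 16 * i:16 * (i + 1),
# 256 * j:256 * (j + 1)]; the repeat(...) call expands it to token resolution (for example, an
# (8, 1) block mask covers a 128 x 256 region) before it is applied.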
def convert_flash_attn_S_to_softmax(
S, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
@@ -229,12 +314,27 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea
nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
mmas_n = (blocksize_n + 16 - 1) // 16
S_flat = rearrange(
S,
"b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
blocksize_m=blocksize_m,
blocksize_n=blocksize_n,
)
S_converted = rearrange(
S_flat,
"b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
mmas_n=mmas_n,
warps_n=warps_n,
eight=8,
c0=2,
c1=2,
c2=2,
four=4,
)
if causal:
causal_mask = torch.triu(
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1
)
S_converted.masked_fill_(causal_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
@@ -245,14 +345,14 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
else:
query_padding_mask = query_padding_mask[:, :seqlen_q]
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None:
if seqlen_k_og < seqlen_k:
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
else:
key_padding_mask = key_padding_mask[:, :seqlen_k]
S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
if seqlen_q_og < seqlen_q:
S_converted = S_converted[:, :, :seqlen_q_og, :]
else:
@@ -264,8 +364,16 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea
return S_converted
def normalize_flash_attn_S(
attn_unnorm,
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
is_dropout=False,
causal=False,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
@@ -278,12 +386,14 @@ def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_pa
q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if causal:
causal_mask = torch.triu(
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
)
scores.masked_fill_(causal_mask, float("-inf"))
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
@@ -291,14 +401,21 @@ def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_pa
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat(
[
a / rearrange(torch.exp(lse - m), "b h s -> b h s 1")
for a, m in zip(attn_unnorm_block, cummax_block)
],
dim=-1,
)
if query_padding_mask is not None:
attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype)
def get_dropout_fraction(
dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False
):
"""
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q)
@@ -307,52 +424,60 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
dropped = ~dropout_mask
if query_padding_mask is not None:
dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
if key_padding_mask is not None:
dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
if causal:
causal_mask = torch.triu(
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
)
dropped.masked_fill_(causal_mask, False)
dropped_total = dropped.sum()
query_lengths = (
query_padding_mask.sum(dim=-1)
if query_padding_mask is not None
else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)
)
key_lengths = (
key_padding_mask.sum(dim=-1)
if key_padding_mask is not None
else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)
)
if not causal:
numel_per_batch = query_lengths * key_lengths
else:
numel_per_batch = torch.where(
query_lengths <= key_lengths,
query_lengths * (query_lengths + 1) / 2,
query_lengths * key_lengths - (key_lengths * (key_lengths - 1) / 2),
)
return dropped_total / (numel_per_batch.sum() * nheads)
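# Worked example of the causal count above (top-left-aligned causal mask, as used by these
# reference implementations): row i of the score matrix keeps min(i + 1, key_length) entries, so
# for query_length = 3, key_length = 5 the kept count is 1 + 2 + 3 = 3 * 4 / 2 = 6, and for
# query_length = 5, key_length = 3 it is 1 + 2 + 3 + 3 + 3 = 5 * 3 - 3 * 2 / 2 = 12, matching the
# two branches of the torch.where above.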
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [97])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 16
nheads = 9
qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal
)
@@ -362,16 +487,25 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
)[:, :, :seqlen, :seqlen]
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(
attn_unnorm,
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
None,
None,
dropout_p > 0.0,
causal=causal,
)
dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal)
out_pt, attn_pt = attention_qkvpacked_ref(
qkv, None, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True
)
# v = qkv[:, :, 2].float()
# qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
# if causal:
@@ -390,30 +524,30 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
# o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
# o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0:
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
# do_o = (g.float() * out.float()).sum(-1)
# dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
# dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
(dqkv,) = torch.autograd.grad(out, qkv, g)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
@@ -427,29 +561,29 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 5
nheads = 6
qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
@@ -466,41 +600,57 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
)[:, :, :seqlen, :seqlen]
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(
attn_unnorm,
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
key_padding_mask,
key_padding_mask,
dropout_p > 0.0,
causal=causal,
)
dropout_fraction = get_dropout_fraction(
dropout_mask, key_padding_mask, key_padding_mask, causal=causal
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(
qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal
)
out_pt, attn_pt = attention_qkvpacked_ref(
qkv,
key_padding_mask,
dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0:
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
(dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
@@ -514,28 +664,45 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 16
@@ -544,13 +711,16 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
assert nheads % nheads_k == 0
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
if kvpacked:
out, lse, S_dmask = flash_attn_kvpacked_func(
@@ -572,58 +742,101 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(
attn_unnorm, q, k_rep, v_rep, None, None, dropout_p > 0.0, causal=causal
)
dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(
q, kv, None, None, dropout_p, dropout_mask, causal=causal
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
kv,
None,
None,
dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
else:
out_ref, attn_ref = attention_ref(
q, k, v, None, None, dropout_p, dropout_mask, causal=causal
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
None,
None,
dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0:
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if kvpacked:
dq, dkv, = torch.autograd.grad(out, (q, kv), g)
(
dq,
dkv,
) = torch.autograd.grad(out, (q, kv), g)
dk, dv = dkv.unbind(2)
dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g)
(
dq_ref,
dkv_ref,
) = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2)
dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g)
(
dq_pt,
dkv_pt,
) = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2)
else:
dq, dk, dv, = torch.autograd.grad(out, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
(
dq,
dk,
dv,
) = torch.autograd.grad(out, (q, k, v), g)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
......@@ -639,26 +852,44 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
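
The assertions above encode the test's tolerance policy: the kernel's deviation from an upcast reference must stay within twice the deviation of a plain low-precision PyTorch baseline of the same math. A minimal, self-contained sketch of that pattern (check_within_2x and the toy tensors are illustrative, not part of the test file):

import torch

def check_within_2x(out, out_ref, out_pt, label="output"):
    # tolerance check: error vs. the full-precision reference must be at most
    # twice the error of the lower-precision PyTorch baseline
    err = (out - out_ref).abs().max().item()
    base = (out_pt - out_ref).abs().max().item()
    assert err <= 2 * base, f"{label}: {err} > 2 * {base}"

x = torch.randn(64, 64)
w = torch.randn(64, 64)
out_ref = x @ w                                                       # full-precision reference
out_pt = x.to(torch.bfloat16).float() @ w.to(torch.bfloat16).float()  # baseline with rounding error
out = out_pt.clone()                                                  # stand-in for the kernel under test
check_within_2x(out, out_ref, out_pt)
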
@pytest.mark.parametrize('kvpacked', [True, False])
@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype,
kvpacked):
if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
device = 'cuda'
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 16
......@@ -667,35 +898,73 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
assert nheads % nheads_k == 0
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked:
kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
k = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode='random')
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='random')
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
if kvpacked:
(q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv,
output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv(
q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True
)
(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, return_attn_probs=True, causal=causal
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
return_attn_probs=True,
causal=causal,
)
else:
(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v,
output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv(
q, k, v, query_padding_mask, key_padding_mask, kvpacked=False
)
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, return_attn_probs=True, causal=causal
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
return_attn_probs=True,
causal=causal,
)
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
......@@ -710,64 +979,112 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep,
query_padding_mask, key_padding_mask,
dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask,
key_padding_mask, causal=causal).item()
print(f'Actual dropout fraction: {dropout_fraction}')
attn = normalize_flash_attn_S(
attn_unnorm,
q,
k_rep,
v_rep,
query_padding_mask,
key_padding_mask,
dropout_p > 0.0,
causal=causal,
)
dropout_fraction = get_dropout_fraction(
dropout_mask, query_padding_mask, key_padding_mask, causal=causal
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask, causal=causal)
out_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
out_ref, attn_ref = attention_kvpacked_ref(
q, kv, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
kv,
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
else:
out_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask, causal=causal)
out_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
out_ref, attn_ref = attention_ref(
q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if kvpacked:
dq_unpad, dkv_unpad, = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
(
dq_unpad,
dkv_unpad,
) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g)
(
dq_ref,
dkv_ref,
) = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2)
dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g)
(
dq_pt,
dkv_pt,
) = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2)
else:
dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
(
dq_unpad,
dk_unpad,
dv_unpad,
) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
dq = dq_pad_fn(dq_unpad)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
......@@ -783,32 +1100,31 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
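
The varlen test above relies on generate_qkv (a repo test utility) to unpad the batch and build cumulative sequence-length tensors. The core idea is a cumulative sum of per-sequence lengths with a leading zero; the snippet below is an illustrative sketch of that step, not the utility itself:

import torch
import torch.nn.functional as F

# hypothetical padding mask: True where a token is real, False where padded
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)
seqlens = mask.sum(dim=-1, dtype=torch.int32)                                  # tokens per sequence: [3, 5]
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))    # [0, 3, 8]
max_seqlen = int(seqlens.max())

# flatten the real tokens of a (batch, seqlen, ...) tensor into (total_tokens, ...)
x = torch.randn(2, 5, 4)
x_unpad = x[mask]                                                              # shape (8, 4)
print(cu_seqlens.tolist(), max_seqlen, x_unpad.shape)
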
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
@pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("seqlen", [128])
# @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
device = 'cuda'
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger
nheads = 4
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype,
requires_grad=True)
out0, lse0, _ = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal
qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
out0, lse0, _ = flash_attn_qkvpacked_func(qkv, dropout_p, return_attn_probs=True, causal=causal)
g = torch.randn_like(out0)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv0, = torch.autograd.grad(out0, qkv, g)
(dqkv0,) = torch.autograd.grad(out0, qkv, g)
# Numerical error if we just do any arithmetic on dq
dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item()
......@@ -821,35 +1137,40 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
assert torch.equal(lse, lse0)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv, = torch.autograd.grad(out, qkv, g)
(dqkv,) = torch.autograd.grad(out, qkv, g)
dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol)
if not dq_equal:
dq0 = dqkv0[:, :, 0]
dq = dqkv[:, :, 0]
print(f'Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}')
print(
f"Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}"
)
assert dq_equal
assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
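
The race-condition test repeats the same forward and backward pass and demands bitwise-identical results via torch.equal, allowing only dQ a small atol because its accumulation order can differ. A stripped-down sketch of that repeat-and-compare pattern, using a plain PyTorch op (run_once is a placeholder, not the FlashAttention kernel):

import torch

def run_once(x, w, g):
    x = x.detach().clone().requires_grad_()
    out = torch.matmul(x, w).softmax(dim=-1)
    out.backward(g)
    return out.detach(), x.grad.detach()

x = torch.randn(8, 128, 64)
w = torch.randn(64, 64)
g = torch.randn(8, 128, 64)
out0, dx0 = run_once(x, w, g)
for i in range(10):
    out, dx = run_once(x, w, g)
    # a deterministic kernel must reproduce both outputs and gradients exactly
    assert torch.equal(out, out0), f"forward mismatch at iter {i}"
    assert torch.equal(dx, dx0), f"backward mismatch at iter {i}"
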
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [16, 32, 64])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize('seqlen', [1, 2, 5, 17, 128])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
""" We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
"""We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
in the case where seqlen % 128 != 0.
"""
device = 'cuda'
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
nheads = 5
q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3 for _ in range(2)]
k, v = [
torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
for _ in range(2)
]
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)
......@@ -866,38 +1187,45 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
v_ref = v.detach().clone().requires_grad_(True)
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
out_ref.backward(g)
print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (q_pt.grad - q_ref.grad).abs().max().item() + 1e-3
assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (k_pt.grad - k_ref.grad).abs().max().item() + 1e-3
assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (v_pt.grad - v_ref.grad).abs().max().item() + 1e-3
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
q_pt.grad - q_ref.grad
).abs().max().item() + 1e-3
assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
k_pt.grad - k_ref.grad
).abs().max().item() + 1e-3
assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
v_pt.grad - v_ref.grad
).abs().max().item() + 1e-3
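
The overflow test scales its inputs (by 5 and 3) so that unmasked positions beyond seqlen_k would blow up into NaN in dQ. Besides the relative-error bounds above, a direct finiteness check is the simplest way to express that regression; the sketch below uses placeholder names (naive_attention, assert_finite_grads) rather than anything from this test file:

import torch

def naive_attention(q, k, v):
    # reference attention in full precision; shapes are (batch, seqlen, nheads, headdim)
    scores = torch.einsum("bshd,bthd->bhst", q, k) / q.shape[-1] ** 0.5
    return torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)

def assert_finite_grads(attn_fn, q, k, v):
    # run forward + backward and fail loudly on any NaN/Inf gradient
    attn_fn(q, k, v).sum().backward()
    for name, grad in (("dq", q.grad), ("dk", k.grad), ("dv", v.grad)):
        assert torch.isfinite(grad).all(), f"{name} contains NaN/Inf"

q = (torch.randn(2, 17, 5, 16) * 5).requires_grad_()
k = (torch.randn(2, 17, 5, 16) * 3).requires_grad_()
v = (torch.randn(2, 17, 5, 16) * 3).requires_grad_()
assert_finite_grads(naive_attention, q, k, v)
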
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [64, 128])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
""" We previously had a bug where we were using the wrong strides of dout, which shows up
"""We previously had a bug where we were using the wrong strides of dout, which shows up
when dout is not contiguous.
"""
device = 'cuda'
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 5
nheads = 2
q, k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda",
requires_grad=True)
for _ in range(3)]
q, k, v = [
torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
for _ in range(3)
]
out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
# So g is not contiguous
g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
......@@ -914,28 +1242,34 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
out_ref = rearrange(out_ref, "b s ... -> s b ...")
out_ref.backward(g)
print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (q_pt.grad - q_ref.grad).abs().max().item()
assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (k_pt.grad - k_ref.grad).abs().max().item()
assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (v_pt.grad - v_ref.grad).abs().max().item()
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
q_pt.grad - q_ref.grad
).abs().max().item()
assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
k_pt.grad - k_ref.grad
).abs().max().item()
assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
v_pt.grad - v_ref.grad
).abs().max().item()
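
The transpose test exercises a backward pass where dout is not contiguous: the output is rearranged to (s, b, ...) and the incoming gradient comes from a strided slice, so its strides no longer describe a densely packed layout. The essential trick in isolation (the rearranged multiply is only a stand-in for flash_attn_func):

import torch
from einops import rearrange

seqlen, batch_size, nheads, d = 97, 5, 2, 64
# slicing every other batch element yields a non-contiguous gradient tensor
g = torch.randn(seqlen, 2 * batch_size, nheads, d)[:, ::2]
assert not g.is_contiguous()

# any op whose backward consumes g must handle these strides correctly
x = torch.randn(batch_size, seqlen, nheads, d, requires_grad=True)
out = rearrange(x * 2.0, "b s ... -> s b ...")  # stand-in for flash_attn_func + rearrange
out.backward(g)
assert x.grad.shape == x.shape
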
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [16, 32, 64])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
""" We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
"""We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
in the case where seqlen % 128 != 0 or varlen.
"""
device = 'cuda'
device = "cuda"
# set seed
torch.random.manual_seed(0)
nheads = 5
......
import math
import pytest
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0)
@pytest.mark.parametrize('dtype', ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize('rotary_fraction', [1.0, 0.5])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5])
@pytest.mark.parametrize('inplace', [False, True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
rtol = 1e-3
......@@ -23,12 +23,13 @@ def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
nheads = 4
seqlen = 217
headdim = 128
x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device='cuda',
requires_grad=True)
x = torch.randn(
batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
)
x_pt = x.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.randn(seqlen, rotary_dim // 2, device='cuda')
angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
out = apply_rotary_emb_func(x, cos, sin, inplace)
......
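
For reference, the rotation exercised by the rotary tests acts on pairs of channels in the first rotary_dim dimensions and leaves the rest untouched. A plain-PyTorch sketch of that math (rotary_ref follows the common split-halves convention; apply_rotary_emb_torch/apply_rotary_emb_func may use a different pairing or an interleaved layout, so treat this as illustrative rather than a drop-in reference):

import torch

def rotary_ref(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2)
    rotary_dim = 2 * cos.shape[-1]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos = cos[:, None, :]  # broadcast over heads
    sin = sin[:, None, :]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)

x = torch.randn(2, 217, 4, 128)
angle = torch.randn(217, 32)  # rotary_fraction = 0.5 -> rotary_dim = 64
out = rotary_ref(x, torch.cos(angle), torch.sin(angle))
assert out.shape == x.shape
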