Commit 0e8c46ae authored by Tri Dao

Run isort and black on test files

parent 7fcd3e6a
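A minimal sketch (not part of this commit) of how such a formatting pass might be reproduced locally; the target directory and the 100-character line length are assumptions inferred from the reflowed code below, not settings taken from the repository:

import subprocess

# Hypothetical target directory; this commit only touches test files.
TEST_DIR = "tests"

# Sort imports first so black only has to reflow already-ordered import blocks.
subprocess.run(["isort", TEST_DIR], check=True)
# Assumed line length of 100, judging by where the calls below are wrapped.
subprocess.run(["black", "--line-length", "100", TEST_DIR], check=True)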
import math
from functools import partial

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.ops.fused_dense import FusedDense, FusedMLP


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_residual", [False, True])
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
    device = "cuda"
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(
        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    model = FusedDense(
        in_features,
        out_features,
        bias=has_bias,
        return_residual=return_residual,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        if has_bias:
@@ -37,10 +42,16 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
        out = model(x)
    else:
        out, x_copy = model(x)
        x_copy = (
            x_copy[..., :out_features]
            if out_features < in_features
            else F.pad(x_copy, (0, out_features - in_features))
        )
        x_pt_copy = (
            x_pt[..., :out_features]
            if out_features < in_features
            else F.pad(x_pt, (0, out_features - in_features))
        )
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt_copy)
        out = out + F.gelu(x_copy)
@@ -60,43 +71,64 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("heuristic", ["auto", -1])
# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize("checkpoint_lvl", [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize("return_residual", [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
@pytest.mark.parametrize("has_bias1", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
@pytest.mark.parametrize("activation", ["gelu_approx", "relu"])
# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
# @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_mlp(
    in_features,
    out_features,
    activation,
    has_bias1,
    has_bias2,
    return_residual,
    checkpoint_lvl,
    heuristic,
    dtype,
):
    device = "cuda"
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(
        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(
        in_features, out_features, bias=has_bias1, device=device, dtype=dtype
    )
    model_pt_fc2 = torch.nn.Linear(
        out_features, in_features, bias=has_bias2, device=device, dtype=dtype
    )
    model = FusedMLP(
        in_features,
        out_features,
        in_features,
        activation=activation,
        bias1=has_bias1,
        bias2=has_bias2,
        return_residual=return_residual,
        checkpoint_lvl=checkpoint_lvl,
        heuristic=heuristic,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        if has_bias1:
@@ -104,8 +136,11 @@ def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2,
        model.fc2.weight.copy_(model_pt_fc2.weight)
        if has_bias2:
            model.fc2.bias.copy_(model_pt_fc2.bias)
    activation_fn = (
        partial(F.gelu, approximate="tanh")
        if activation == "gelu_approx"
        else partial(F.relu, inplace=True)
    )
    out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
    if not return_residual:
        out = model(x)
@@ -121,13 +156,17 @@ def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2,
    out_pt.backward(g)
    out.backward(g)
    # The error for relu is higher still
    if activation == "relu":
        atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10
    )
    if has_bias1:
        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(
        model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10
    )
    if has_bias2:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
@@ -3,36 +3,33 @@
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, FusedMLP, ParallelFusedMLP

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias", [True, False])
# @pytest.mark.parametrize('has_bias', [False])
@pytest.mark.parametrize("out_features", [1024])
@pytest.mark.parametrize("in_features", [4096])
def test_fused_linear_bias(
    in_features, out_features, has_bias, sequence_parallel, world_size, dtype
):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
@@ -41,77 +38,95 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle
    batch_size = 2
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    partition_out_features = out_features // world_size
    model = ColumnParallelLinear(
        in_features,
        out_features,
        parallel_state.get_tensor_model_parallel_group(),
        bias=has_bias,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.weight.copy_(
            model_pt.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        if has_bias:
            model.bias.copy_(
                model_pt.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
            )
    out = model(x)
    out_pt = model_pt(x_pt)
    assert torch.allclose(
        out,
        out_pt[:, rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol,
    )
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out_pt) / 32
    out_pt.backward(g)
    out.backward(g[:, rank * partition_out_features : (rank + 1) * partition_out_features])
    parallel_state.destroy_model_parallel()
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.weight.grad,
        model_pt.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 10,
    )
    if has_bias:
        assert torch.allclose(
            model.bias.grad,
            model_pt.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
            rtol=rtol,
            atol=atol * 5,
        )


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize("out_features", [4096])
@pytest.mark.parametrize("in_features", [1024])
def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
@@ -120,77 +135,103 @@ def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, worl
    batch_size = 2
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(
        out_features, in_features, bias=has_bias2, device=device, dtype=dtype
    )
    partition_out_features = out_features // world_size
    partition_in_features = in_features // world_size
    model = ParallelFusedMLP(
        in_features,
        out_features,
        in_features,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        bias2=has_bias2 and rank == 0,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.fc1.weight.copy_(
            model_pt_fc1.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        model.fc1.bias.copy_(
            model_pt_fc1.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        model.fc2.weight.copy_(
            model_pt_fc2.weight[
                :, rank * partition_out_features : (rank + 1) * partition_out_features
            ]
        )
        if has_bias2 and rank == 0:
            model.fc2.bias.copy_(model_pt_fc2.bias)
    out = model(x)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate="tanh"))
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad,
        model_pt_fc1.weight.grad[
            rank * partition_out_features : (rank + 1) * partition_out_features
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        model_pt_fc1.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt_fc2.weight.grad[
            :, rank * partition_out_features : (rank + 1) * partition_out_features
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    if has_bias2 and rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
import math import math
import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import pytest
from einops import rearrange, repeat from einops import rearrange, repeat
from flash_attn import (
from flash_attn import flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func flash_attn_func,
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func flash_attn_kvpacked_func,
from flash_attn import flash_attn_varlen_func flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
MAX_HEADDIM_SM8x = 192 MAX_HEADDIM_SM8x = 192
is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5) is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability('cuda')[0] == 8 is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0) is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability('cuda') == (9, 0) is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'): def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ['full', 'random', 'third'] assert mode in ["full", "random", "third"]
if mode == 'full': if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == 'random': elif mode == "random":
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
elif mode == 'third': elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
padding_mask = repeat(torch.arange(max_seqlen, device=device), 's -> b s', b=batch_size) < lengths padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
)
return padding_mask return padding_mask
def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, def generate_qkv(
kvpacked=False, qkvpacked=False): q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
""" """
Arguments: Arguments:
q: (batch_size, seqlen_q, nheads, d) q: (batch_size, seqlen_q, nheads, d)
...@@ -53,22 +57,28 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, ...@@ -53,22 +57,28 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
if query_padding_mask is not None: if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) output_pad_fn = lambda output_unpad: pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else: else:
q_unpad = rearrange(q, 'b s h d -> (b s) h d') q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, cu_seqlens_q = torch.arange(
device=q_unpad.device) 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
)
max_seqlen_q = seqlen_q max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(output_unpad, '(b s) h d -> b s h d', b=batch_size) output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None: if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask) v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else: else:
k_unpad = rearrange(k, 'b s h d -> (b s) h d') k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, 'b s h d -> (b s) h d') v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, cu_seqlens_k = torch.arange(
device=k_unpad.device) 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
)
max_seqlen_k = seqlen_k max_seqlen_k = seqlen_k
if qkvpacked: if qkvpacked:
...@@ -79,9 +89,17 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, ...@@ -79,9 +89,17 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
if query_padding_mask is not None: if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else: else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, '(b s) t h d -> b s t h d', b=batch_size) dqkv_pad_fn = lambda dqkv_unpad: rearrange(
return (qkv_unpad.detach().requires_grad_(), cu_seqlens_q, max_seqlen_q, dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
qkv.detach().requires_grad_(), output_pad_fn, dqkv_pad_fn) )
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked: elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2) kv = torch.stack([k, v], dim=2)
...@@ -89,27 +107,57 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, ...@@ -89,27 +107,57 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
if key_padding_mask is not None: if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else: else:
dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, '(b s) t h d -> b s t h d', b=batch_size) dkv_pad_fn = lambda dkv_unpad: rearrange(
return (q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(), dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, )
q.detach().requires_grad_(), kv.detach().requires_grad_(), return (
output_pad_fn, dq_pad_fn, dkv_pad_fn) q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else: else:
dq_pad_fn = output_pad_fn dq_pad_fn = output_pad_fn
if key_padding_mask is not None: if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
else: else:
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, '(b s) h d -> b s h d', b=batch_size) dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), return (
v_unpad.detach().requires_grad_(), q_unpad.detach().requires_grad_(),
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, k_unpad.detach().requires_grad_(),
q.detach().requires_grad_(), k.detach().requires_grad_(), v_unpad.detach().requires_grad_(),
v.detach().requires_grad_(), cu_seqlens_q,
output_pad_fn, dq_pad_fn, dk_pad_fn) cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0, def attention_ref(
dropout_mask=None, causal=False, upcast=True, reorder_ops=False): q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
upcast=True,
reorder_ops=False,
):
""" """
Arguments: Arguments:
q: (batch_size, seqlen_q, nheads, head_dim) q: (batch_size, seqlen_q, nheads, head_dim)
...@@ -136,14 +184,16 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo ...@@ -136,14 +184,16 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1] d = q.shape[-1]
if not reorder_ops: if not reorder_ops:
scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k) scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else: else:
scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d)) scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if key_padding_mask is not None: if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf')) scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if causal: if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) causal_mask = torch.triu(
scores.masked_fill_(causal_mask, float('-inf')) torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
)
scores.masked_fill_(causal_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1) attention = torch.softmax(scores, dim=-1)
dropout_scaling = 1.0 / (1 - dropout_p) dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
...@@ -152,25 +202,59 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo ...@@ -152,25 +202,59 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
attention_drop = attention.masked_fill(~dropout_mask, 0.0) attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else: else:
attention_drop = attention attention_drop = attention
output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling) output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None: if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0) output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
attention = attention.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def attention_kvpacked_ref(q, kv, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0, def attention_kvpacked_ref(
dropout_mask=None, causal=False, upcast=True, reorder_ops=False): q,
return attention_ref(q, kv[:, :, 0], kv[:, :, 1], query_padding_mask, kv,
key_padding_mask, dropout_p, dropout_mask, upcast=upcast, causal=causal, query_padding_mask=None,
reorder_ops=reorder_ops) key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
upcast=True,
reorder_ops=False,
):
return attention_ref(
q,
kv[:, :, 0],
kv[:, :, 1],
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
reorder_ops=reorder_ops,
)
def attention_qkvpacked_ref(qkv, key_padding_mask=None, dropout_p=0.0, def attention_qkvpacked_ref(
dropout_mask=None, causal=False, upcast=True, reorder_ops=False): qkv,
return attention_ref(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], key_padding_mask, key_padding_mask=None,
key_padding_mask, dropout_p, dropout_mask, upcast=upcast, causal=causal, dropout_p=0.0,
reorder_ops=reorder_ops) dropout_mask=None,
causal=False,
upcast=True,
reorder_ops=False,
):
return attention_ref(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
key_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
reorder_ops=reorder_ops,
)
def generate_sparsity_mask(seqlen, sparsity=0.3): def generate_sparsity_mask(seqlen, sparsity=0.3):
...@@ -182,7 +266,7 @@ def generate_sparsity_mask(seqlen, sparsity=0.3): ...@@ -182,7 +266,7 @@ def generate_sparsity_mask(seqlen, sparsity=0.3):
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1) # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
nrow, ncol = seqlen // 16, seqlen // 256 nrow, ncol = seqlen // 16, seqlen // 256
mask = torch.rand(nrow, ncol, device='cuda') < sparsity mask = torch.rand(nrow, ncol, device="cuda") < sparsity
return mask return mask
...@@ -201,22 +285,23 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask ...@@ -201,22 +285,23 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask
q, k, v = qkv.float().unbind(dim=2) q, k, v = qkv.float().unbind(dim=2)
d = qkv.shape[-1] d = qkv.shape[-1]
seqlen = qkv.shape[1] seqlen = qkv.shape[1]
scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k) scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf')) scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
blockmask = repeat(blockmask, 's_16 s_256 -> (s_16 16) (s_256 256)') blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
blockmask = blockmask[:seqlen, :seqlen] blockmask = blockmask[:seqlen, :seqlen]
scores.masked_fill_(rearrange(~blockmask, 't s -> 1 1 t s'), float('-inf')) scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
attention = torch.softmax(scores, dim=-1) attention = torch.softmax(scores, dim=-1)
attention = attention.masked_fill(rearrange(~attn_mask, 'b s -> b 1 s 1'), 0.0) attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
attention = attention.masked_fill_(rearrange(~blockmask, 't s -> 1 1 t s'), 0.0) attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p) attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
output = torch.einsum('bhts,bshd->bthd', attention_drop , v) output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1 1'), 0) output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype) return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, head_dim, is_dropout, def convert_flash_attn_S_to_softmax(
causal=False): S, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
):
"""FlashAttention stores the S matrix in a different way. """FlashAttention stores the S matrix in a different way.
Arguments: Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded) S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
...@@ -229,12 +314,27 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea ...@@ -229,12 +314,27 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea
nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
mmas_n = (blocksize_n + 16 - 1) // 16 mmas_n = (blocksize_n + 16 - 1) // 16
S_flat = rearrange(S, 'b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)', S_flat = rearrange(
blocksize_m=blocksize_m, blocksize_n=blocksize_n) S,
S_converted = rearrange(S_flat, 'b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)', "b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
mmas_n=mmas_n, warps_n=warps_n, eight=8, c0=2, c1=2, c2=2, four=4) blocksize_m=blocksize_m,
blocksize_n=blocksize_n,
)
S_converted = rearrange(
S_flat,
"b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
mmas_n=mmas_n,
warps_n=warps_n,
eight=8,
c0=2,
c1=2,
c2=2,
four=4,
)
if causal: if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1) causal_mask = torch.triu(
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1
)
S_converted.masked_fill_(causal_mask, 0.0) S_converted.masked_fill_(causal_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values # Need to zero out things not in attention_mask in case S was initialized with random values
...@@ -245,14 +345,14 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea ...@@ -245,14 +345,14 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og)) query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
else: else:
query_padding_mask = query_padding_mask[:, :seqlen_q] query_padding_mask = query_padding_mask[:, :seqlen_q]
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None: if key_padding_mask is not None:
if seqlen_k_og < seqlen_k: if seqlen_k_og < seqlen_k:
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og)) key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
else: else:
key_padding_mask = key_padding_mask[:, :seqlen_k] key_padding_mask = key_padding_mask[:, :seqlen_k]
S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), 0.0) S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
if seqlen_q_og < seqlen_q: if seqlen_q_og < seqlen_q:
S_converted = S_converted[:, :, :seqlen_q_og, :] S_converted = S_converted[:, :, :seqlen_q_og, :]
else: else:
...@@ -264,8 +364,16 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea ...@@ -264,8 +364,16 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea
return S_converted return S_converted
def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_padding_mask=None, def normalize_flash_attn_S(
is_dropout=False, causal=False): attn_unnorm,
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
is_dropout=False,
causal=False,
):
""" """
Arguments: Arguments:
q: (batch_size, seqlen_q, nheads, head_dim) q: (batch_size, seqlen_q, nheads, head_dim)
...@@ -278,12 +386,14 @@ def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_pa ...@@ -278,12 +386,14 @@ def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_pa
q, k, v = q.float(), k.float(), v.float() q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape _, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1] seqlen_k = k.shape[1]
scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(head_dim), k) scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
if key_padding_mask is not None: if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf')) scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if causal: if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) causal_mask = torch.triu(
scores.masked_fill_(causal_mask, float('-inf')) torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
)
scores.masked_fill_(causal_mask, float("-inf"))
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal) _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1) scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
...@@ -291,14 +401,21 @@ def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_pa ...@@ -291,14 +401,21 @@ def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_pa
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat([a / rearrange(torch.exp(lse - m), 'b h s -> b h s 1') attn_norm = torch.cat(
for a, m in zip(attn_unnorm_block, cummax_block)], dim=-1) [
a / rearrange(torch.exp(lse - m), "b h s -> b h s 1")
for a, m in zip(attn_unnorm_block, cummax_block)
],
dim=-1,
)
if query_padding_mask is not None: if query_padding_mask is not None:
attn_norm.masked_fill_(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype) return attn_norm.to(dtype=attn_unnorm.dtype)
def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False): def get_dropout_fraction(
dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False
):
""" """
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop. dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q) query_padding_mask: (batch_size, seqlen_q)
...@@ -307,52 +424,60 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask ...@@ -307,52 +424,60 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
dropped = ~dropout_mask dropped = ~dropout_mask
if query_padding_mask is not None: if query_padding_mask is not None:
dropped.masked_fill_(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), False) dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
if key_padding_mask is not None: if key_padding_mask is not None:
dropped.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), False) dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
if causal: if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, causal_mask = torch.triu(
device=dropout_mask.device), 1) torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
)
dropped.masked_fill_(causal_mask, False) dropped.masked_fill_(causal_mask, False)
dropped_total = dropped.sum() dropped_total = dropped.sum()
query_lengths = (query_padding_mask.sum(dim=-1) if query_padding_mask is not None query_lengths = (
else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)) query_padding_mask.sum(dim=-1)
key_lengths = (key_padding_mask.sum(dim=-1) if key_padding_mask is not None if query_padding_mask is not None
else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)) else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)
)
key_lengths = (
key_padding_mask.sum(dim=-1)
if key_padding_mask is not None
else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)
)
if not causal: if not causal:
numel_per_batch = query_lengths * key_lengths numel_per_batch = query_lengths * key_lengths
else: else:
numel_per_batch = torch.where( numel_per_batch = torch.where(
query_lengths <= key_lengths, query_lengths <= key_lengths,
query_lengths * (query_lengths + 1) / 2, query_lengths * (query_lengths + 1) / 2,
query_lengths * key_lengths - (key_lengths * (key_lengths - 1) / 2) query_lengths * key_lengths - (key_lengths * (key_lengths - 1) / 2),
) )
return dropped_total / (numel_per_batch.sum() * nheads) return dropped_total / (numel_per_batch.sum() * nheads)
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True]) # @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128]) # @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [97]) # @pytest.mark.parametrize('seqlen', [97])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.17]) # @pytest.mark.parametrize('dropout_p', [0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
device = 'cuda' device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 16 batch_size = 16
nheads = 9 nheads = 9
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, qkv = torch.randn(
requires_grad=True) batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
out, lse, S_dmask = flash_attn_qkvpacked_func( out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal qkv, dropout_p, return_attn_probs=True, causal=causal
) )
...@@ -362,16 +487,25 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -362,16 +487,25 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
)[:, :, :seqlen, :seqlen] )[:, :, :seqlen, :seqlen]
dropout_mask = S_dmask_converted >= 0 dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs() attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], attn = normalize_flash_attn_S(
None, None, dropout_p > 0.0, causal=causal) attn_unnorm,
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
None,
None,
dropout_p > 0.0,
causal=causal,
)
dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item() dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
print(f'Actual dropout fraction: {dropout_fraction}') print(f"Actual dropout fraction: {dropout_fraction}")
else: else:
dropout_mask = None dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal) out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal)
out_pt, attn_pt = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal, out_pt, attn_pt = attention_qkvpacked_ref(
upcast=False, reorder_ops=True) qkv, None, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True
)
# v = qkv[:, :, 2].float() # v = qkv[:, :, 2].float()
# qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float() # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
# if causal: # if causal:
...@@ -390,30 +524,30 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -390,30 +524,30 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:]) # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
# o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:]) # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
# o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :]) # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}') print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}') print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0: if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out) g = torch.randn_like(out)
# do_o = (g.float() * out.float()).sum(-1) # do_o = (g.float() * out.float()).sum(-1)
# dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64]) # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
# dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:]) # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv, = torch.autograd.grad(out, qkv, g) (dqkv,) = torch.autograd.grad(out, qkv, g)
dqkv_ref, = torch.autograd.grad(out_ref, qkv, g) (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
dqkv_pt, = torch.autograd.grad(out_pt, qkv, g) (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}') print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}') print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error # Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation. # of a Pytorch implementation.
...@@ -427,29 +561,29 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -427,29 +561,29 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
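The assert above encodes the accuracy contract used throughout these tests: the fused kernel is compared against an fp32 reference, and its worst-case error must stay within twice the error of a plain fp16/bf16 PyTorch implementation of the same math. A minimal sketch of that pattern; the helper name and tensor arguments are hypothetical, not part of the test file.

import torch

def assert_within_2x_baseline(test_out: torch.Tensor, ref_fp32: torch.Tensor, baseline_lowprec: torch.Tensor) -> None:
    # Hypothetical helper mirroring the check above: the kernel under test is
    # allowed numerical error, but no more than 2x that of a naive low-precision
    # PyTorch implementation measured against the same fp32 reference.
    err = (test_out - ref_fp32).abs().max().item()
    baseline_err = (baseline_lowprec - ref_fp32).abs().max().item()
    assert err <= 2 * baseline_err, f"max err {err} exceeds 2 * {baseline_err}"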
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128]) # @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
device = 'cuda' device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 5 batch_size = 5
nheads = 6 nheads = 6
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, qkv = torch.randn(
requires_grad=True) batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
...@@ -466,41 +600,57 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -466,41 +600,57 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
)[:, :, :seqlen, :seqlen] )[:, :, :seqlen, :seqlen]
dropout_mask = S_dmask_converted >= 0 dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs() attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], attn = normalize_flash_attn_S(
key_padding_mask, key_padding_mask, dropout_p > 0.0, attn_unnorm,
causal=causal) qkv[:, :, 0],
dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask, qkv[:, :, 1],
causal=causal).item() qkv[:, :, 2],
print(f'Actual dropout fraction: {dropout_fraction}') key_padding_mask,
key_padding_mask,
dropout_p > 0.0,
causal=causal,
)
dropout_fraction = get_dropout_fraction(
dropout_mask, key_padding_mask, key_padding_mask, causal=causal
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else: else:
dropout_mask = None dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, out_ref, attn_ref = attention_qkvpacked_ref(
causal=causal) qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal
out_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, )
causal=causal, upcast=False, reorder_ops=True) out_pt, attn_pt = attention_qkvpacked_ref(
print(f'Output max diff: {(out - out_ref).abs().max().item()}') qkv,
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') key_padding_mask,
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}') dropout_p,
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}') dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0: if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out) g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv_unpad, = torch.autograd.grad(out, qkv_unpad, g) (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad) dqkv = dqkv_pad_fn(dqkv_unpad)
dqkv_ref, = torch.autograd.grad(out_ref, qkv, g) (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
dqkv_pt, = torch.autograd.grad(out_pt, qkv, g) (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}') print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}') print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error # Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation. # of a Pytorch implementation.
...@@ -514,28 +664,45 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -514,28 +664,45 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
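The varlen tests above feed the kernel unpadded ("packed") sequences plus cumulative sequence lengths instead of a padded batch. A minimal sketch of how such inputs can be derived from a boolean padding mask (True marking valid tokens); the helper name is hypothetical and this is not the repo's generate_qkv.

import torch
import torch.nn.functional as F

def unpad_with_cu_seqlens(x: torch.Tensor, mask: torch.Tensor):
    # x: (batch, seqlen, ...); mask: (batch, seqlen) bool with True for real tokens.
    seqlens = mask.sum(dim=1, dtype=torch.int32)                                  # tokens per sequence
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # shape (batch + 1,)
    x_unpad = x[mask]                                                             # (total_tokens, ...)
    max_seqlen = int(seqlens.max())
    return x_unpad, cu_seqlens, max_seqlen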
@pytest.mark.parametrize('kvpacked', [True, False]) @pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False]) # @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16]) # @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mha"]) # @pytest.mark.parametrize('mha_type', ["mha"])
@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)]) @pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked): def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
device = 'cuda' device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 16 batch_size = 16
...@@ -544,13 +711,16 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -544,13 +711,16 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
assert nheads % nheads_k == 0 assert nheads % nheads_k == 0
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked: if kvpacked:
kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, kv = torch.randn(
requires_grad=True) batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else: else:
k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, k = torch.randn(
requires_grad=True) batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, )
requires_grad=True) v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
if kvpacked: if kvpacked:
out, lse, S_dmask = flash_attn_kvpacked_func( out, lse, S_dmask = flash_attn_kvpacked_func(
...@@ -572,58 +742,101 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -572,58 +742,101 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
else: else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep, attn = normalize_flash_attn_S(
None, None, dropout_p > 0.0, causal=causal) attn_unnorm, q, k_rep, v_rep, None, None, dropout_p > 0.0, causal=causal
)
dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item() dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
print(f'Actual dropout fraction: {dropout_fraction}') print(f"Actual dropout fraction: {dropout_fraction}")
else: else:
dropout_mask = None dropout_mask = None
if kvpacked: if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(q, kv, None, None, dropout_p, dropout_mask, out_ref, attn_ref = attention_kvpacked_ref(
causal=causal) q, kv, None, None, dropout_p, dropout_mask, causal=causal
out_pt, attn_pt = attention_kvpacked_ref(q, kv, None, None, dropout_p, dropout_mask, )
causal=causal, upcast=False, reorder_ops=True) out_pt, attn_pt = attention_kvpacked_ref(
q,
kv,
None,
None,
dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
else: else:
out_ref, attn_ref = attention_ref(q, k, v, None, None, dropout_p, dropout_mask, out_ref, attn_ref = attention_ref(
causal=causal) q, k, v, None, None, dropout_p, dropout_mask, causal=causal
out_pt, attn_pt = attention_ref(q, k, v, None, None, dropout_p, dropout_mask, )
causal=causal, upcast=False, reorder_ops=True) out_pt, attn_pt = attention_ref(
q,
print(f'Output max diff: {(out - out_ref).abs().max().item()}') k,
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') v,
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}') None,
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}') None,
dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0: if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out) g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1) do_o = (g.float() * out.float()).sum(-1)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if kvpacked: if kvpacked:
dq, dkv, = torch.autograd.grad(out, (q, kv), g) (
dq,
dkv,
) = torch.autograd.grad(out, (q, kv), g)
dk, dv = dkv.unbind(2) dk, dv = dkv.unbind(2)
dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g) (
dq_ref,
dkv_ref,
) = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2) dk_ref, dv_ref = dkv_ref.unbind(2)
dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g) (
dq_pt,
dkv_pt,
) = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2) dk_pt, dv_pt = dkv_pt.unbind(2)
else: else:
dq, dk, dv, = torch.autograd.grad(out, (q, k, v), g) (
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g) dq,
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g) dk,
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') dv,
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') ) = torch.autograd.grad(out, (q, k, v), g)
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') (
print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}') dq_ref,
print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}') dk_ref,
print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}') dv_ref,
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') ) = torch.autograd.grad(out_ref, (q, k, v), g)
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') (
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') dq_pt,
print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}') dk_pt,
print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}') dv_pt,
print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}') ) = torch.autograd.grad(out_pt, (q, k, v), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error # Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation. # of a Pytorch implementation.
...@@ -639,26 +852,44 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -639,26 +852,44 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
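For the "mqa" and "gqa" cases above, K and V carry nheads_k < nheads heads, and the reference path expands them with einops.repeat so each group of query heads attends to its shared KV head. A small standalone illustration of that expansion (toy shapes, not the test's):

import torch
from einops import repeat

nheads, nheads_k = 8, 2                     # GQA: 4 query heads share each KV head
k = torch.randn(1, 16, nheads_k, 64)        # (batch, seqlen, nheads_k, d)
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
assert k_rep.shape == (1, 16, nheads, 64)
# Output heads 0..3 are copies of KV head 0, heads 4..7 of KV head 1.
assert torch.equal(k_rep[:, :, 0], k[:, :, 0]) and torch.equal(k_rep[:, :, 3], k[:, :, 0])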
@pytest.mark.parametrize('kvpacked', [True, False]) @pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False]) # @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"]) # @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True]) # @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)]) @pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, def test_flash_attn_varlen_output(
kvpacked): seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked
if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: ):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
device = 'cuda' device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 16 batch_size = 16
...@@ -667,35 +898,73 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_ ...@@ -667,35 +898,73 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
assert nheads % nheads_k == 0 assert nheads % nheads_k == 0
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked: if kvpacked:
kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, kv = torch.randn(
requires_grad=True) batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else: else:
k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, k = torch.randn(
requires_grad=True) batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, )
requires_grad=True) v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode='random') query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='random') key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
if kvpacked: if kvpacked:
(q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv, (
output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv( q_unpad,
q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True kv_unpad,
) cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func( out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q_unpad,
dropout_p, return_attn_probs=True, causal=causal kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
return_attn_probs=True,
causal=causal,
) )
else: else:
(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v, (
output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv( q_unpad,
q, k, v, query_padding_mask, key_padding_mask, kvpacked=False k_unpad,
) v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_func( out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q_unpad,
dropout_p, return_attn_probs=True, causal=causal k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
return_attn_probs=True,
causal=causal,
) )
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
if dropout_p > 0.0: if dropout_p > 0.0:
...@@ -710,64 +979,112 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_ ...@@ -710,64 +979,112 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
else: else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep, attn = normalize_flash_attn_S(
query_padding_mask, key_padding_mask, attn_unnorm,
dropout_p > 0.0, causal=causal) q,
dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, k_rep,
key_padding_mask, causal=causal).item() v_rep,
print(f'Actual dropout fraction: {dropout_fraction}') query_padding_mask,
key_padding_mask,
dropout_p > 0.0,
causal=causal,
)
dropout_fraction = get_dropout_fraction(
dropout_mask, query_padding_mask, key_padding_mask, causal=causal
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else: else:
dropout_mask = None dropout_mask = None
if kvpacked: if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask, out_ref, attn_ref = attention_kvpacked_ref(
dropout_p, dropout_mask, causal=causal) q, kv, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
out_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask, )
dropout_p, dropout_mask, out_pt, attn_pt = attention_kvpacked_ref(
causal=causal, upcast=False, reorder_ops=True) q,
kv,
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
else: else:
out_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask, out_ref, attn_ref = attention_ref(
dropout_p, dropout_mask, causal=causal) q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
out_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask, )
dropout_p, dropout_mask, out_pt, attn_pt = attention_ref(
causal=causal, upcast=False, reorder_ops=True) q,
k,
print(f'Output max diff: {(out - out_ref).abs().max().item()}') v,
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') query_padding_mask,
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}') key_padding_mask,
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}') dropout_p,
dropout_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0: if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out) g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if kvpacked: if kvpacked:
dq_unpad, dkv_unpad, = torch.autograd.grad(out, (q_unpad, kv_unpad), g) (
dq_unpad,
dkv_unpad,
) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
dk, dv = dkv_pad_fn(dkv_unpad).unbind(2) dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g) (
dq_ref,
dkv_ref,
) = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2) dk_ref, dv_ref = dkv_ref.unbind(2)
dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g) (
dq_pt,
dkv_pt,
) = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2) dk_pt, dv_pt = dkv_pt.unbind(2)
else: else:
dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) (
dq_unpad,
dk_unpad,
dv_unpad,
) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
dk = dk_pad_fn(dk_unpad) dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad) dv = dk_pad_fn(dv_unpad)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g) (
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g) dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
dq = dq_pad_fn(dq_unpad) dq = dq_pad_fn(dq_unpad)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}') print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}') print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}') print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}') print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}') print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}') print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error # Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation. # of a Pytorch implementation.
...@@ -783,32 +1100,31 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_ ...@@ -783,32 +1100,31 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
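When dropout_p > 0, the tests recover the dropout mask from the returned S_dmask and compare the realized drop rate against dropout_p. A minimal sketch of that measurement for the simplest case (no padding, no causal mask); the function name is hypothetical, not the repo's get_dropout_fraction.

import torch

def empirical_dropout_fraction(dropout_mask: torch.Tensor) -> float:
    # dropout_mask: (batch, nheads, seqlen_q, seqlen_k) bool, True = entry kept.
    # The realized fraction of dropped attention entries should land close to
    # dropout_p for large tensors; the tests print it for manual inspection.
    return 1.0 - dropout_mask.float().mean().item()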
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
@pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128]) # @pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize("seqlen", [128])
# @pytest.mark.parametrize('dropout_p', [0.0, 0.17]) # @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0]) @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
device = 'cuda' device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger
nheads = 4 nheads = 4
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, qkv = torch.randn(
requires_grad=True) batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
out0, lse0, _ = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal
) )
out0, lse0, _ = flash_attn_qkvpacked_func(qkv, dropout_p, return_attn_probs=True, causal=causal)
g = torch.randn_like(out0) g = torch.randn_like(out0)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv0, = torch.autograd.grad(out0, qkv, g) (dqkv0,) = torch.autograd.grad(out0, qkv, g)
# Numerical error if we just do any arithmetic on dq # Numerical error if we just do any arithmetic on dq
dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item() dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item()
...@@ -821,35 +1137,40 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): ...@@ -821,35 +1137,40 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
assert torch.equal(lse, lse0) assert torch.equal(lse, lse0)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv, = torch.autograd.grad(out, qkv, g) (dqkv,) = torch.autograd.grad(out, qkv, g)
dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol) dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol)
if not dq_equal: if not dq_equal:
dq0 = dqkv0[:, :, 0] dq0 = dqkv0[:, :, 0]
dq = dqkv[:, :, 0] dq = dqkv[:, :, 0]
print(f'Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}') print(
f"Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}"
)
assert dq_equal assert dq_equal
assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1]) assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2]) assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
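The race-condition test above reruns the identical forward/backward and demands bitwise-equal results, with only dQ allowed a tiny atol derived from ordinary float rounding. A minimal sketch of that determinism-check pattern, with a hypothetical callable f and inputs.

import torch

def check_deterministic(f, *inputs, iters: int = 5) -> None:
    # Re-running the same computation must give identical tensors; any drift
    # between iterations suggests a race (e.g. non-deterministic accumulation),
    # not normal rounding error, since the inputs and seed are unchanged.
    ref = f(*inputs)
    for _ in range(iters):
        assert torch.equal(f(*inputs), ref)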
@pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [16, 32, 64]) @pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16]) # @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize('seqlen', [1, 2, 5, 17, 128]) @pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2]) # @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
""" We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
in the case where seqlen % 128 != 0. in the case where seqlen % 128 != 0.
""" """
device = 'cuda' device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 2 batch_size = 2
nheads = 5 nheads = 5
q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5 q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3 for _ in range(2)] k, v = [
torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
for _ in range(2)
]
q.requires_grad_(True) q.requires_grad_(True)
k.requires_grad_(True) k.requires_grad_(True)
v.requires_grad_(True) v.requires_grad_(True)
...@@ -866,38 +1187,45 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): ...@@ -866,38 +1187,45 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
v_ref = v.detach().clone().requires_grad_(True) v_ref = v.detach().clone().requires_grad_(True)
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
out_ref.backward(g) out_ref.backward(g)
print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}') print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}') print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}') print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}') print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}') print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}') print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (q_pt.grad - q_ref.grad).abs().max().item() + 1e-3 assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (k_pt.grad - k_ref.grad).abs().max().item() + 1e-3 q_pt.grad - q_ref.grad
assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (v_pt.grad - v_ref.grad).abs().max().item() + 1e-3 ).abs().max().item() + 1e-3
assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
k_pt.grad - k_ref.grad
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) ).abs().max().item() + 1e-3
assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
v_pt.grad - v_ref.grad
).abs().max().item() + 1e-3
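The overflow test above scales q by 5 and k, v by 3, which pushes the attention logits toward the top of the fp16 range. A rough back-of-the-envelope for why, assuming d = 64 (illustrative numbers only, not taken from the test):

import math
import torch

print(torch.finfo(torch.float16).max)   # 65504.0, the largest finite fp16 value
# With q ~ 5 * N(0, 1) and k ~ 3 * N(0, 1), one unscaled q.k dot product over
# d = 64 channels has standard deviation 5 * 3 * sqrt(64) = 120, so multi-sigma
# logits stress fp16 if intermediates are kept in half precision.
print(5 * 3 * math.sqrt(64))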
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16]) # @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [64, 128]) @pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128]) # @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
""" We previously had a bug where we were using the wrong strides of dout, which shows up """We previously had a bug where we were using the wrong strides of dout, which shows up
when dout is not contiguous. when dout is not contiguous.
""" """
device = 'cuda' device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 5 batch_size = 5
nheads = 2 nheads = 2
q, k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", q, k, v = [
requires_grad=True) torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
for _ in range(3)] for _ in range(3)
]
out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...") out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
# So g is not contiguous # So g is not contiguous
g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2] g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
...@@ -914,28 +1242,34 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): ...@@ -914,28 +1242,34 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
out_ref = rearrange(out_ref, "b s ... -> s b ...") out_ref = rearrange(out_ref, "b s ... -> s b ...")
out_ref.backward(g) out_ref.backward(g)
print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}') print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}') print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}') print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}') print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}') print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}') print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (q_pt.grad - q_ref.grad).abs().max().item() assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (k_pt.grad - k_ref.grad).abs().max().item() q_pt.grad - q_ref.grad
assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (v_pt.grad - v_ref.grad).abs().max().item() ).abs().max().item()
assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
k_pt.grad - k_ref.grad
@pytest.mark.parametrize('dtype', [torch.float16]) ).abs().max().item()
@pytest.mark.parametrize('causal', [False, True]) assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
v_pt.grad - v_ref.grad
).abs().max().item()
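The transpose test above builds its incoming gradient by allocating twice the batch and slicing every other row, so dout reaches the backward kernel with non-trivial strides. A tiny standalone illustration of why that slicing yields a non-contiguous tensor (toy shapes):

import torch

g_full = torch.randn(4, 6, 2, 8)   # (seqlen, 2 * batch_size, nheads, d), toy sizes
g = g_full[:, ::2]                 # keep every other batch row: a strided view
assert g.shape == (4, 3, 2, 8)
assert not g.is_contiguous()       # the backward must honor these strides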
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [16, 32, 64]) @pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16]) # @pytest.mark.parametrize('d', [16])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
""" We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
in the case where seqlen % 128 != 0 or varlen. in the case where seqlen % 128 != 0 or varlen.
""" """
device = 'cuda' device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
nheads = 5 nheads = 5
......
import math import math
import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import pytest
from einops import rearrange from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0)
@pytest.mark.parametrize('dtype', ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16])) # @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize('rotary_fraction', [1.0, 0.5]) @pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5]) # @pytest.mark.parametrize('rotary_fraction', [0.5])
@pytest.mark.parametrize('inplace', [False, True]) @pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False]) # @pytest.mark.parametrize('inplace', [False])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype): def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
rtol = 1e-3 rtol = 1e-3
...@@ -23,12 +23,13 @@ def test_rotary_single_tensor(inplace, rotary_fraction, dtype): ...@@ -23,12 +23,13 @@ def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
nheads = 4 nheads = 4
seqlen = 217 seqlen = 217
headdim = 128 headdim = 128
x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device='cuda', x = torch.randn(
requires_grad=True) batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
)
x_pt = x.detach().clone().requires_grad_() x_pt = x.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim) rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0 assert rotary_dim % 2 == 0
angle = torch.randn(seqlen, rotary_dim // 2, device='cuda') angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
cos = torch.cos(angle).to(dtype=dtype) cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype) sin = torch.sin(angle).to(dtype=dtype)
out = apply_rotary_emb_func(x, cos, sin, inplace) out = apply_rotary_emb_func(x, cos, sin, inplace)
......
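The rotary test compares apply_rotary_emb_func against a PyTorch reference on the first rotary_dim channels. A minimal sketch of that rotation, assuming the non-interleaved (GPT-NeoX-style) channel split; this layout is an assumption for illustration, and apply_rotary_emb_torch's exact convention is defined in the repo.

import torch

def rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim // 2).
    # Rotate the first rotary_dim channels pairwise; leave the rest untouched.
    rotary_dim = 2 * cos.shape[-1]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)            # assumed non-interleaved pairing
    cos = cos[None, :, None, :]                # broadcast over batch and heads
    sin = sin[None, :, None, :]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)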