Unverified Commit 81c89111 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

Add test for flash_attn_varlen_func kernel (#5484)

parent 92d1561b
...@@ -296,6 +296,152 @@ def attention_ref( ...@@ -296,6 +296,152 @@ def attention_ref(
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
add_unused_qkv=False,
query_unused_mask=None,
key_unused_mask=None,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_unused_mask is not None or key_unused_mask is not None:
assert not kvpacked
assert not qkvpacked
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
q,
query_padding_mask,
query_unused_mask,
)
output_pad_fn = lambda output_unpad: pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
seqused_q = None
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
k, key_padding_mask, key_unused_mask
)
v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
seqused_k = None
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(
dqkv_unpad, indices_q, batch_size, seqlen_q
)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(
dkv_unpad, indices_k, batch_size, seqlen_k
)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: pad_input(
dk_unpad, indices_k, batch_size, seqlen_k
)
else:
dk_pad_fn = lambda dk_unpad: rearrange(
dk_unpad, "(b s) h d -> b s h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
@pytest.mark.skipif( @pytest.mark.skipif(
not is_fa3_supported(), not is_fa3_supported(),
reason="flash_attn at sgl-kernel is only supported on sm90 and above", reason="flash_attn at sgl-kernel is only supported on sm90 and above",
...@@ -855,5 +1001,320 @@ def _generate_block_kvcache( ...@@ -855,5 +1001,320 @@ def _generate_block_kvcache(
return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
"dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
# @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("add_unused_qkv", [False, True])
# @pytest.mark.parametrize("add_unused_qkv", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", COMPILED_HDIMS)
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 1),
(1, 3),
(2, 1),
(511, 1),
(3, 513),
(64, 128),
(128, 128),
(256, 256),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(307, 256),
(640, 128),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
def test_flash_attn_varlen_output(
seqlen_q,
seqlen_k,
d,
add_unused_qkv,
causal,
local,
softcap,
deterministic,
has_qv,
mha_type,
dtype,
):
from sgl_kernel.flash_attn import flash_attn_varlen_func
device = "cuda"
# set seed
torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
# batch_size = 40
# nheads = 16
batch_size = 9 if seqlen_q <= 2048 else 2
nheads = 6
# batch_size = 2
# nheads = 1
nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
if dtype == torch.float8_e4m3fn:
dv_vals = [d]
for dv in dv_vals:
q_ref = torch.randn(
batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
)
if softcap > 0.0:
# Ensure the values of qk are at least within softcap range.
q_ref = (q_ref * softcap / 4).detach().requires_grad_()
q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
k_ref = (
torch.randn(
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
.requires_grad_()
)
v_ref = (
torch.randn(
batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
.requires_grad_()
)
if has_qv:
qv_ref = (
torch.randn(
batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
)
else:
qv_ref = None
# Put window_size after QKV randn so that window_size changes from test to test
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
if dtype == torch.float8_e4m3fn:
q_descale, k_descale, v_descale = [
torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
* 2
for _ in range(3)
]
else:
q_descale, k_descale, v_descale = None, None, None
q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
qv = qv_ref.detach() if has_qv else None
query_padding_mask = generate_random_padding_mask(
seqlen_q, batch_size, device, mode="random", zero_lengths=False
)
key_padding_mask = generate_random_padding_mask(
seqlen_k, batch_size, device, mode="random", zero_lengths=True
)
def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
if add_unused:
another_mask = generate_random_padding_mask(max_seq_len, bs, device)
attn_mask = torch.logical_and(padding_mask, another_mask)
unused_mask = torch.logical_xor(
torch.logical_or(padding_mask, another_mask), attn_mask
)
else:
attn_mask = padding_mask
unused_mask = None
return attn_mask, unused_mask
query_padding_mask, query_unused_mask = _gen_unused_masks(
query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
)
key_padding_mask, key_unused_mask = _gen_unused_masks(
key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
)
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(
q,
k,
v,
query_padding_mask,
key_padding_mask,
kvpacked=False,
query_unused_mask=query_unused_mask,
key_unused_mask=key_unused_mask,
)
q_unpad, k_unpad, v_unpad = [
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
]
out_ref, attn_ref = attention_ref(
q_ref,
k_ref,
v_ref,
query_padding_mask,
key_padding_mask,
causal=causal,
qv=qv_ref,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q_ref,
k_ref,
v_ref,
query_padding_mask,
key_padding_mask,
causal=causal,
qv=qv_ref,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
)
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if query_unused_mask is not None:
q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
# Numerical error if we just do any arithmetic on out_ref
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2 if softcap == 0.0 else 3
pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
out_unpad, lse, *rest = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seqused_q=seqused_q,
seqused_k=seqused_k,
causal=causal,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
softcap=softcap,
return_softmax_lse=True,
)
out = output_pad_fn(out_unpad)
if query_unused_mask is not None:
out.masked_fill_(q_zero_masking, 0.0)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most 3x the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= rtol * (
out_pt - out_ref
).abs().max().item() + fwd_atol
if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
g_unpad = torch.randn_like(out_unpad)
do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
)
dq = dq_pad_fn(dq_unpad)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
if key_unused_mask is not None:
k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
dk.masked_fill_(k_zero_masking, 0.0)
dv.masked_fill_(k_zero_masking, 0.0)
if query_unused_mask is not None:
dq.masked_fill_(q_zero_masking, 0.0)
# print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
# assert (softmax_d - do_o).abs().max().item() <= 1e-5
# assert dq_accum.abs().max().item() == 0.0
g = output_pad_fn(g_unpad)
# dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dq - dq_ref).abs().max().item() <= rtol * (
dq_pt - dq_ref
).abs().max().item() + dq_atol
dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dk - dk_ref).abs().max().item() <= rtol * (
dk_pt - dk_ref
).abs().max().item() + dk_atol
dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dv - dv_ref).abs().max().item() <= rtol * (
dv_pt - dv_ref
).abs().max().item() + dv_atol
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment