Unverified commit 94f44b88 authored by Ke Bao, committed by GitHub

Update fa3 interface and add unit test (#9150)

parent 3b3b3baf
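This commit threads an optional per-head sinks tensor (attention-sink logits) through the FA3 kernel schema, the mha_fwd declaration, and the Python wrappers flash_attn_with_kvcache and flash_attn_varlen_func, and extends the unit tests to cover it. A minimal sketch of the new argument on the kvcache path follows; tensor shapes mirror the test in this diff, while keywords other than sinks and return_softmax_lse come from the pre-existing wrapper signature, so treat the exact call as illustrative rather than a pinned API.

import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache

# Decode-style shapes, mirroring the kvcache unit test below (hedged example, not a benchmark).
batch, seqlen_q, seqlen_k, nheads, nheads_k, d = 2, 1, 128, 8, 8, 128
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, seqlen_q, nheads, d, device=device, dtype=dtype)
k_cache = torch.randn(batch, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch, seqlen_k, nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), seqlen_k, device=device, dtype=torch.int32)
sinks = torch.randn(nheads, device=device, dtype=torch.bfloat16)  # one sink logit per query head

out, softmax_lse, *rest = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    sinks=sinks,  # new optional argument added by this commit
    return_softmax_lse=True,
)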
@@ -55,7 +55,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor? scheduler_metadata,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin) -> Tensor[]");
" int sm_margin,"
" Tensor? sinks) -> Tensor[]");
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}
@@ -82,4 +82,5 @@ std::vector<at::Tensor> mha_fwd(
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin);
int const sm_margin,
std::optional<const at::Tensor>& sinks_);
@@ -58,6 +58,7 @@ def flash_attn_with_kvcache(
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
sinks=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -205,6 +206,7 @@ def flash_attn_with_kvcache(
num_splits,
pack_gqa,
sm_margin,
sinks,
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out
@@ -232,6 +234,7 @@ def flash_attn_varlen_func(
pack_gqa=None,
sm_margin=0,
return_softmax_lse=False,
sinks=None,
):
if not is_fa3_supported():
raise NotImplementedError(
@@ -277,6 +280,7 @@ def flash_attn_varlen_func(
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
sinks=sinks,
)
return (out, softmax_lse, *rest) if return_softmax_lse else out
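The varlen wrapper gains the same optional argument. A hedged sketch of the packed, variable-length path follows; the cu_seqlens/max_seqlen keywords are assumed from the upstream FlashAttention-3 hopper interface this wrapper is adapted from, so only sinks and return_softmax_lse are taken directly from this diff. Running either example requires hardware supported by the FA3 kernels; otherwise the wrapper raises NotImplementedError, as shown in the hunk above.

import torch
from sgl_kernel.flash_attn import flash_attn_varlen_func

nheads, d = 8, 128
seqlens = [64, 96]  # two sequences packed along the token dimension
total_tokens = sum(seqlens)
device, dtype = "cuda", torch.bfloat16

q = torch.randn(total_tokens, nheads, d, device=device, dtype=dtype)
k = torch.randn(total_tokens, nheads, d, device=device, dtype=dtype)
v = torch.randn(total_tokens, nheads, d, device=device, dtype=dtype)
cu_seqlens = torch.tensor([0, 64, 160], device=device, dtype=torch.int32)  # prefix sums of seqlens
sinks = torch.randn(nheads, device=device, dtype=torch.bfloat16)

out, softmax_lse, *rest = flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
    sinks=sinks,  # new optional (nheads,) sink logits
    return_softmax_lse=True,
)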
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
import itertools
import math
import os
from typing import Optional
import pytest
import torch
@@ -45,12 +45,12 @@ DISABLE_BACKWARD = True
# or torch.cuda.get_device_capability("cuda")[0] < 9
# )
DISABLE_SPLIT = True
DISABLE_SPLIT = False
DISABLE_PAGEDKV = True
DISABLE_APPENDKV = True
DISABLE_LOCAL = True
DISABLE_APPENDKV = False
DISABLE_LOCAL = False
DISABLE_SOFTCAP = True
DISABLE_PACKGQA = True
DISABLE_PACKGQA = False
DISABLE_FP16 = True
DISABLE_FP8 = True
@@ -199,6 +199,7 @@ def attention_ref(
v_descale=None,
window_size=(-1, -1), # -1 means infinite window size
sink_token_length=0,
sinks: Optional[torch.Tensor] = None,
softcap=0.0,
upcast=True,
reorder_ops=False,
@@ -271,7 +272,18 @@ def attention_ref(
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
if sinks is None:
attention = torch.softmax(scores, dim=-1).to(v.dtype)
else:
scores_fp32 = scores.to(torch.float32)
logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
sinks = rearrange(sinks, "h -> h 1 1")
logits_or_sinks_max = torch.maximum(sinks, logits_max)
unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
sinks - logits_or_sinks_max
)
attention = (unnormalized_scores / normalizer).to(v.dtype)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
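In the reference change above, each head's sink logit joins the row-wise softmax normalizer but carries no value vector, so it only drains probability mass from the real keys. The numerically stable form used in the diff is equivalent to appending the sink as one extra logit column, taking an ordinary softmax, and discarding that column; a small CPU-only sketch with arbitrary shapes checks that the two forms agree.

import torch

nheads, seqlen_q, seqlen_k = 4, 3, 5
scores = torch.randn(nheads, seqlen_q, seqlen_k, dtype=torch.float32)
sinks = torch.randn(nheads, dtype=torch.float32).view(nheads, 1, 1)

# Form used in attention_ref: shift by max(sink, row max), add exp(sink - m) to the normalizer.
m = torch.maximum(sinks, scores.amax(dim=-1, keepdim=True))
exp_scores = torch.exp(scores - m)
p_ref = exp_scores / (exp_scores.sum(dim=-1, keepdim=True) + torch.exp(sinks - m))

# Equivalent view: softmax over the scores with one appended sink column, then drop that column.
augmented = torch.cat([scores, sinks.expand(nheads, seqlen_q, 1)], dim=-1)
p_aug = torch.softmax(augmented, dim=-1)[..., :-1]

assert torch.allclose(p_ref, p_aug, atol=1e-6)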
@@ -459,8 +471,10 @@ def generate_qkv(
)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("has_sink", [False, True])
# @pytest.mark.parametrize("has_sink", [False])
@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
# @pytest.mark.parametrize("new_kv", [True])
# @pytest.mark.parametrize(
@@ -540,6 +554,7 @@ def test_flash_attn_kvcache(
new_kv,
mha_type,
dtype,
has_sink,
):
from sgl_kernel.flash_attn import flash_attn_with_kvcache
@@ -565,6 +580,12 @@ def test_flash_attn_kvcache(
assert nheads % nheads_k == 0
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
if has_sink:
sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
else:
sinks = None
if dtype == torch.float8_e4m3fn or not is_hopper():
# for fp8 and the ampere arch, we do not support v head dim != qk head dim
dv_vals = [d]
@@ -820,6 +841,7 @@ def test_flash_attn_kvcache(
qv=qv,
window_size=window_size,
key_leftpad=cache_leftpad,
sinks=sinks,
)
out_pt, _ = attention_ref(
q_ro,
@@ -834,6 +856,7 @@ def test_flash_attn_kvcache(
reorder_ops=True,
key_leftpad=cache_leftpad,
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
sinks=sinks,
)
q = q.to(dtype)
q_unpad = q_unpad.to(dtype) if varlen_q else None
@@ -888,6 +911,7 @@ def test_flash_attn_kvcache(
scheduler_metadata=scheduler_metadata,
num_splits=num_splits,
return_softmax_lse=True,
sinks=sinks,
)
if varlen_q:
out = output_pad_fn(out)
@@ -1019,8 +1043,10 @@ def _generate_block_kvcache(
)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("has_sink", [False, True])
# @pytest.mark.parametrize("has_sink", [False])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@@ -1078,6 +1104,7 @@ def test_flash_attn_varlen_output(
has_qv,
mha_type,
dtype,
has_sink,
):
from sgl_kernel.flash_attn import flash_attn_varlen_func
@@ -1131,6 +1158,7 @@ def test_flash_attn_varlen_output(
qv_ref = None
# Put window_size after QKV randn so that window_size changes from test to test
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
if has_sink:
sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
else:
sinks = None
if dtype == torch.float8_e4m3fn:
q_descale, k_descale, v_descale = [
torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
@@ -1209,6 +1242,7 @@ def test_flash_attn_varlen_output(
v_descale=v_descale,
window_size=window_size,
softcap=softcap,
sinks=sinks,
)
out_pt, attn_pt = attention_ref(
q_ref,
@@ -1226,6 +1260,7 @@ def test_flash_attn_varlen_output(
upcast=False,
reorder_ops=True,
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
sinks=sinks,
)
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1258,6 +1293,7 @@ def test_flash_attn_varlen_output(
window_size=window_size,
softcap=softcap,
return_softmax_lse=True,
sinks=sinks,
)
out = output_pad_fn(out_unpad)
if query_unused_mask is not None: