Unverified Commit 94f44b88 authored by Ke Bao, committed by GitHub

Update fa3 interface and add unit test (#9150)

parent 3b3b3baf
@@ -55,7 +55,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       " Tensor? scheduler_metadata,"
       " int num_splits,"
       " bool? pack_gqa,"
-      " int sm_margin) -> Tensor[]");
+      " int sm_margin,"
+      " Tensor? sinks) -> Tensor[]");
   m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
@@ -82,4 +82,5 @@ std::vector<at::Tensor> mha_fwd(
     std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
     int num_splits,
     std::optional<bool> pack_gqa_,
-    int const sm_margin);
+    int const sm_margin,
+    std::optional<const at::Tensor>& sinks_);
@@ -58,6 +58,7 @@ def flash_attn_with_kvcache(
     pack_gqa=None,  # Can be tuned for speed
     sm_margin=0,  # Can be tuned if some SMs are used for communication
     return_softmax_lse=False,
+    sinks=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -205,6 +206,7 @@ def flash_attn_with_kvcache(
         num_splits,
         pack_gqa,
         sm_margin,
+        sinks,
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out
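
For context, a minimal decode-style call using the new argument could look like the sketch below. The q/k_cache/v_cache layout and the `cache_seqlens` / `causal` keywords are assumed from the upstream FlashAttention interface and are not part of this diff; `sinks` follows the per-head shape used in the unit test further down.

```python
# Minimal sketch (not part of this commit): single-token decode with per-head
# attention sinks. Shapes and the cache_seqlens / causal keywords are assumed
# from the upstream FlashAttention interface; sinks has shape (nheads,).
import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache

batch, nheads, nheads_k, headdim, max_seq = 2, 8, 8, 128, 4096
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(batch, max_seq, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn(batch, max_seq, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
cache_seqlens = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")
sinks = torch.randn(nheads, device="cuda", dtype=torch.bfloat16)  # one sink logit per head

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    sinks=sinks,  # new optional argument added by this commit
)
```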
@@ -232,6 +234,7 @@ def flash_attn_varlen_func(
     pack_gqa=None,
     sm_margin=0,
     return_softmax_lse=False,
+    sinks=None,
 ):
     if not is_fa3_supported():
         raise NotImplementedError(
@@ -277,6 +280,7 @@ def flash_attn_varlen_func(
         num_splits=num_splits,
         pack_gqa=pack_gqa,
         sm_margin=sm_margin,
         sinks=sinks,
     )
     return (out, softmax_lse, *rest) if return_softmax_lse else out
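
A corresponding varlen-path sketch, with the same caveats: the packed token layout and the `cu_seqlens` / `max_seqlen` / `causal` keywords are assumed from the upstream varlen interface and are not shown in this diff.

```python
# Minimal sketch (not part of this commit): variable-length prefill with sinks.
# Argument layout assumed from the upstream FlashAttention varlen interface.
import torch
from sgl_kernel.flash_attn import flash_attn_varlen_func

nheads, nheads_k, headdim = 8, 8, 128
seqlens = [5, 11]                      # two sequences packed along the token dim
total_tokens = sum(seqlens)
cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device="cuda")
q = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total_tokens, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total_tokens, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
sinks = torch.randn(nheads, device="cuda", dtype=torch.bfloat16)

out = flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
    sinks=sinks,  # forwarded to the kernel like the other tuning kwargs
)
```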
 # Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
 import itertools
 import math
-import os
+from typing import Optional
 import pytest
 import torch
@@ -45,12 +45,12 @@ DISABLE_BACKWARD = True
 # or torch.cuda.get_device_capability("cuda")[0] < 9
 # )
-DISABLE_SPLIT = True
+DISABLE_SPLIT = False
 DISABLE_PAGEDKV = True
-DISABLE_APPENDKV = True
-DISABLE_LOCAL = True
+DISABLE_APPENDKV = False
+DISABLE_LOCAL = False
 DISABLE_SOFTCAP = True
-DISABLE_PACKGQA = True
+DISABLE_PACKGQA = False
 DISABLE_FP16 = True
 DISABLE_FP8 = True
@@ -199,6 +199,7 @@ def attention_ref(
     v_descale=None,
     window_size=(-1, -1),  # -1 means infinite window size
     sink_token_length=0,
+    sinks: Optional[torch.Tensor] = None,
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
@@ -271,7 +272,18 @@ def attention_ref(
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
         scores = scores + attn_bias
-    attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    if sinks is None:
+        attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    else:
+        scores_fp32 = scores.to(torch.float32)
+        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
+        sinks = rearrange(sinks, "h -> h 1 1")
+        logits_or_sinks_max = torch.maximum(sinks, logits_max)
+        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
+            sinks - logits_or_sinks_max
+        )
+        attention = (unnormalized_scores / normalizer).to(v.dtype)
     # We want to mask here so that the attention matrix doesn't have any NaNs
     # Otherwise we'll get NaN in dV
     if query_padding_mask is not None:
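
The branch added above implements a softmax whose normalizer carries one extra, value-less sink logit per head. With s_h the sink logit of head h and the row-wise maximum taken over both the scores and the sink for numerical stability, the reference computes:

```latex
% Sink-augmented softmax implemented by the else-branch of attention_ref above.
% m_{h,i} is the row-wise max over the attention scores and the sink logit.
m_{h,i} = \max\!\Big(s_h,\ \max_{j}\,\mathrm{scores}_{h,i,j}\Big)
\qquad
\mathrm{attn}_{h,i,j}
  = \frac{\exp\!\big(\mathrm{scores}_{h,i,j} - m_{h,i}\big)}
         {\exp\!\big(s_h - m_{h,i}\big) + \sum_{k}\exp\!\big(\mathrm{scores}_{h,i,k} - m_{h,i}\big)}
```

The sink contributes only to the denominator and carries no value vector, so each attention row can sum to less than one; with sinks=None the branch reduces to the ordinary softmax.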
@@ -459,8 +471,10 @@ def generate_qkv(
 )
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("has_sink", [False, True])
+# @pytest.mark.parametrize("has_sink", [False])
 @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
 # @pytest.mark.parametrize("new_kv", [True])
 # @pytest.mark.parametrize(
@@ -540,6 +554,7 @@ def test_flash_attn_kvcache(
     new_kv,
     mha_type,
     dtype,
+    has_sink,
 ):
     from sgl_kernel.flash_attn import flash_attn_with_kvcache
@@ -565,6 +580,12 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
+    if has_sink:
+        sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+    else:
+        sinks = None
     if dtype == torch.float8_e4m3fn or not is_hopper():
         # for fp8 and ampere arch, we not support v head dim != qk head dim
         dv_vals = [d]
@@ -820,6 +841,7 @@ def test_flash_attn_kvcache(
         qv=qv,
         window_size=window_size,
         key_leftpad=cache_leftpad,
+        sinks=sinks,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -834,6 +856,7 @@ def test_flash_attn_kvcache(
         reorder_ops=True,
         key_leftpad=cache_leftpad,
         intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+        sinks=sinks,
     )
     q = q.to(dtype)
     q_unpad = q_unpad.to(dtype) if varlen_q else None
@@ -888,6 +911,7 @@ def test_flash_attn_kvcache(
         scheduler_metadata=scheduler_metadata,
         num_splits=num_splits,
         return_softmax_lse=True,
+        sinks=sinks,
     )
     if varlen_q:
         out = output_pad_fn(out)
@@ -1019,8 +1043,10 @@ def _generate_block_kvcache(
 )
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("has_sink", [False, True])
+# @pytest.mark.parametrize("has_sink", [False])
 # @pytest.mark.parametrize("has_qv", [False, True])
 @pytest.mark.parametrize("has_qv", [False])
 # @pytest.mark.parametrize("deterministic", [False, True])
@@ -1078,6 +1104,7 @@ def test_flash_attn_varlen_output(
     has_qv,
     mha_type,
     dtype,
+    has_sink,
 ):
     from sgl_kernel.flash_attn import flash_attn_varlen_func
@@ -1131,6 +1158,12 @@ def test_flash_attn_varlen_output(
     qv_ref = None
     # Put window_size after QKV randn so that window_size changes from test to test
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    if has_sink:
+        sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+    else:
+        sinks = None
     if dtype == torch.float8_e4m3fn:
         q_descale, k_descale, v_descale = [
             torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
@@ -1209,6 +1242,7 @@ def test_flash_attn_varlen_output(
         v_descale=v_descale,
         window_size=window_size,
         softcap=softcap,
+        sinks=sinks,
     )
     out_pt, attn_pt = attention_ref(
         q_ref,
@@ -1226,6 +1260,7 @@ def test_flash_attn_varlen_output(
         upcast=False,
         reorder_ops=True,
         intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+        sinks=sinks,
     )
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1258,6 +1293,7 @@ def test_flash_attn_varlen_output(
         window_size=window_size,
         softcap=softcap,
         return_softmax_lse=True,
+        sinks=sinks,
     )
     out = output_pad_fn(out_unpad)
     if query_unused_mask is not None: