Unverified commit 252dc4e1 authored by Johnny, committed by GitHub

[NVIDIA] FA3/FA4 Fix (#11606)


Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
parent cbb5fc2e
...@@ -1071,6 +1071,16 @@ class ServerArgs: ...@@ -1071,6 +1071,16 @@ class ServerArgs:
self.enable_mixed_chunk = False self.enable_mixed_chunk = False
self.disable_radix_cache = True self.disable_radix_cache = True
if self.attention_backend == "fa4" or self.decode_attention_backend == "fa4":
raise ValueError(
"FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead."
)
if self.prefill_attention_backend == "fa4":
logger.warning(
f"FA4 backend only supports page size 128, changing page_size from {self.page_size} to 128."
)
self.page_size = 128
def _handle_page_size(self): def _handle_page_size(self):
if self.page_size is None: if self.page_size is None:
self.page_size = 1 self.page_size = 1
......
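A minimal, self-contained sketch of the backend normalization added in the ServerArgs hunk above, using a simplified stand-in for ServerArgs (only the fields this hunk touches; `_Args` and `normalize_fa4_args` are illustrative names, not part of sglang):

import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)

@dataclass
class _Args:  # stand-in for ServerArgs with only the fields used below
    attention_backend: Optional[str] = None
    decode_attention_backend: Optional[str] = None
    prefill_attention_backend: Optional[str] = None
    page_size: Optional[int] = None

def normalize_fa4_args(args: _Args) -> _Args:
    # FA4 is prefill-only: reject it as the general or decode backend.
    if args.attention_backend == "fa4" or args.decode_attention_backend == "fa4":
        raise ValueError(
            "FA4 backend is only supported for prefill. "
            "Please use `--prefill-attention-backend fa4` instead."
        )
    # FA4 prefill only works with 128-token pages: coerce and warn.
    if args.prefill_attention_backend == "fa4":
        logger.warning(
            f"FA4 backend only supports page size 128, "
            f"changing page_size from {args.page_size} to 128."
        )
        args.page_size = 128
    return args

print(normalize_fa4_args(_Args(prefill_attention_backend="fa4", page_size=1)).page_size)  # 128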
...@@ -129,6 +129,11 @@ def is_in_amd_ci(): ...@@ -129,6 +129,11 @@ def is_in_amd_ci():
return get_bool_env_var("SGLANG_IS_IN_CI_AMD") return get_bool_env_var("SGLANG_IS_IN_CI_AMD")
def is_blackwell_system():
"""Return whether it is running on a Blackwell (B200) system."""
return get_bool_env_var("IS_BLACKWELL")
def _use_cached_default_models(model_repo: str): def _use_cached_default_models(model_repo: str):
cache_dir = os.getenv("DEFAULT_MODEL_CACHE_DIR") cache_dir = os.getenv("DEFAULT_MODEL_CACHE_DIR")
if cache_dir and model_repo: if cache_dir and model_repo:
...@@ -151,6 +156,9 @@ DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 10 ...@@ -151,6 +156,9 @@ DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 10
if is_in_amd_ci(): if is_in_amd_ci():
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3000 DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3000
if is_blackwell_system():
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3000
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None assert url is not None
......
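The test-utils hunk keys Blackwell detection off an IS_BLACKWELL environment variable rather than probing the device. A small sketch of the same gating, with a local stand-in for sglang's get_bool_env_var (the helper's exact parsing and the baseline timeout value here are assumptions):

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # local stand-in; sglang ships its own helper with the same intent
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

def is_blackwell_system() -> bool:
    """Return whether the CI job declares itself a Blackwell (B200) system."""
    return get_bool_env_var("IS_BLACKWELL")

DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000  # illustrative baseline, not the real default
if is_blackwell_system():
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3000  # matches the bump in the hunk above

print(DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)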
...@@ -91,7 +91,7 @@ FetchContent_Populate(repo-flashinfer) ...@@ -91,7 +91,7 @@ FetchContent_Populate(repo-flashinfer)
FetchContent_Declare( FetchContent_Declare(
repo-flash-attention repo-flash-attention
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_TAG f9af0c2a1d82ab1812e6987e9338363cc2bf0f8d GIT_TAG ff87110aad048bb8c4e6effea4c563ddae88b0eb
GIT_SHALLOW OFF GIT_SHALLOW OFF
) )
FetchContent_Populate(repo-flash-attention) FetchContent_Populate(repo-flash-attention)
...@@ -100,7 +100,7 @@ FetchContent_Populate(repo-flash-attention) ...@@ -100,7 +100,7 @@ FetchContent_Populate(repo-flash-attention)
FetchContent_Declare( FetchContent_Declare(
repo-flash-attention-origin repo-flash-attention-origin
GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git
GIT_TAG 203b9b3dba39d5d08dffb49c09aa622984dff07d GIT_TAG 04adaf0e9028d4bec7073f69e4dfa3f6d3357189
GIT_SHALLOW OFF GIT_SHALLOW OFF
) )
FetchContent_Populate(repo-flash-attention-origin) FetchContent_Populate(repo-flash-attention-origin)
......
...@@ -23,40 +23,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -23,40 +23,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
* From flash-attention * From flash-attention
*/ */
m.def( m.def(
"fwd(Tensor! q," "fwd(Tensor q," // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
" Tensor k," " Tensor k," // (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged
" Tensor v," " Tensor v," // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged
" Tensor? k_new," " Tensor? k_new," // (b, s_k_new, h_k, d) or (total_k_new, h_k, d)
" Tensor? v_new," " Tensor? v_new," // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv)
" Tensor? q_v," " Tensor? q_v," // (b, s_q, h, dv) or (total_q_new, h, dv)
" Tensor!? out," " Tensor? out," // (b, s_q, h, dv) or (total_q, h, dv)
" Tensor? cu_seqlens_q," " Tensor? cu_seqlens_q," // b+1
" Tensor? cu_seqlens_k," " Tensor? cu_seqlens_k," // b+1
" Tensor? cu_seqlens_k_new," " Tensor? cu_seqlens_k_new," // b+1
" Tensor? seqused_q," " Tensor? seqused_q," // b
" Tensor? seqused_k," " Tensor? seqused_k," // b
" int? max_seqlen_q," " int? max_seqlen_q,"
" int? max_seqlen_k," " int? max_seqlen_k," // TODO: check if needed
" Tensor? page_table," " Tensor? page_table," // (b_k, max_num_pages_per_seq)
" Tensor? kv_batch_idx," " Tensor? kv_batch_idx," // b
" Tensor? leftpad_k," " Tensor? leftpad_k," // b
" Tensor? rotary_cos," " Tensor? rotary_cos," // seqlen_ro x (rotary_dim / 2)
" Tensor? rotary_sin," " Tensor? rotary_sin," // seqlen_ro x (rotary_dim / 2)
" Tensor? seqlens_rotary," " Tensor? seqlens_rotary," // b
" Tensor? q_descale," " Tensor? q_descale," // (b, h_k)
" Tensor? k_descale," " Tensor? k_descale," // (b, h_k)
" Tensor? v_descale," " Tensor? v_descale," // (b, h_k)
" float softmax_scale," " float? softmax_scale," // now optional
" bool is_causal," " bool is_causal,"
" int window_size_left," " int window_size_left,"
" int window_size_right," " int window_size_right,"
" float softcap," " int attention_chunk," // NEW
" float softcap," // promoted to double in C++; schema float is fine
" bool is_rotary_interleaved," " bool is_rotary_interleaved,"
" Tensor? scheduler_metadata," " Tensor? scheduler_metadata," // (b + 1)
" int num_splits," " int num_splits,"
" bool? pack_gqa," " bool? pack_gqa,"
" int sm_margin," " int sm_margin,"
" Tensor? sinks) -> Tensor[]"); " Tensor? sinks"
") -> (Tensor, Tensor, Tensor, Tensor)"); // NEW return type: tuple of 4 tensors
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd)); m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
} }
......
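The schema above gains an int attention_chunk slot ahead of softcap, makes softmax_scale optional, and switches the return from Tensor[] to a 4-tensor tuple, which the Python wrappers later in this commit unpack as `out, softmax_lse, *rest`. Below is a minimal sketch of how such a schema behaves from Python, under a hypothetical `sgl_demo` namespace with a pared-down argument list and naive CPU math; it is not the real kernel or its full signature:

import torch

# Hypothetical namespace/op for illustration only; the real op is torch.ops.sgl_kernel.fwd.
lib = torch.library.Library("sgl_demo", "DEF")
lib.define(
    "fwd(Tensor q, Tensor k, Tensor v, Tensor? out, float? softmax_scale, "
    "bool is_causal, int window_size_left, int window_size_right, "
    "int attention_chunk, float softcap) -> (Tensor, Tensor, Tensor, Tensor)"
)

def fwd_cpu(q, k, v, out, softmax_scale, is_causal,
            window_size_left, window_size_right, attention_chunk, softcap):
    # Naive reference math, not the FA3 kernel; layouts are (b, s, h, d).
    scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
    if is_causal:
        s_q, s_k = scores.shape[-2], scores.shape[-1]
        causal_mask = torch.arange(s_k) > torch.arange(s_q)[:, None]
        scores = scores.masked_fill(causal_mask, float("-inf"))
    o = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(-1), v.float()).to(q.dtype)
    lse = torch.logsumexp(scores, dim=-1)            # (b, h, s_q)
    return o, lse, torch.empty(0), torch.empty(0)    # trailing tensors stand in for the extra outputs

lib.impl("fwd", fwd_cpu, "CPU")

q = k = v = torch.randn(1, 4, 2, 8)
out, softmax_lse, *rest = torch.ops.sgl_demo.fwd(q, k, v, None, None, True, -1, -1, 0, 0.0)
print(out.shape, softmax_lse.shape)  # torch.Size([1, 4, 2, 8]) torch.Size([1, 2, 4])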
...@@ -42,45 +42,44 @@ limitations under the License. ...@@ -42,45 +42,44 @@ limitations under the License.
/* /*
* From flash-attention * From flash-attention
*/ */
The mha_fwd declaration is updated to match the new kernel ABI (it previously returned std::vector<at::Tensor>, took q and the optional tensors by reference as std::optional<const at::Tensor>&, used int/float scalars, required softmax_scale, and had no attention_chunk). The new declaration:

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_fwd(
    at::Tensor q,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    at::Tensor k,   // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
    at::Tensor v,   // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
    std::optional<at::Tensor> k_new_,            // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
    std::optional<at::Tensor> v_new_,            // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
    std::optional<at::Tensor> q_v_,              // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
    std::optional<at::Tensor> out_,              // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
    std::optional<at::Tensor> cu_seqlens_q_,     // b+1
    std::optional<at::Tensor> cu_seqlens_k_,     // b+1
    std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
    std::optional<at::Tensor> seqused_q_,        // b. If given, only this many elements of each batch element's queries and outputs are used.
    std::optional<at::Tensor> seqused_k_,        // b. If given, only this many elements of each batch element's keys are used.
    std::optional<int64_t> max_seqlen_q_,
    // TODO: check if we need max_seqlen_k
    std::optional<int64_t> max_seqlen_k_,
    std::optional<at::Tensor> page_table_,       // (b_k, max_num_pages_per_seq)
    std::optional<at::Tensor> kv_batch_idx_,     // b. indices to index into the KV cache
    std::optional<at::Tensor> leftpad_k_,        // b
    std::optional<at::Tensor> rotary_cos_,       // seqlen_ro x (rotary_dim / 2)
    std::optional<at::Tensor> rotary_sin_,       // seqlen_ro x (rotary_dim / 2)
    std::optional<at::Tensor> seqlens_rotary_,   // b
    std::optional<at::Tensor> q_descale_,        // (b, h_k), not (b, h)
    std::optional<at::Tensor> k_descale_,        // (b, h_k)
    std::optional<at::Tensor> v_descale_,        // (b, h_k)
    std::optional<double> softmax_scale_,
    bool is_causal,
    int64_t window_size_left,
    int64_t window_size_right,
    int64_t attention_chunk,
    double softcap,
    bool is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
    std::optional<at::Tensor> scheduler_metadata_,  // (b + 1)
    int64_t num_splits,
    std::optional<bool> pack_gqa_,
    int64_t sm_margin,
    std::optional<const at::Tensor>& sinks_);    // (h)
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/203b9b3dba39d5d08dffb49c09aa622984dff07d/flash_attn/cute/interface.py # Adapted from https://github.com/Dao-AILab/flash-attention/blob/54d8aa6751fc9d5f0357854079261913d5df1f9d/flash_attn/cute/interface.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0. # [2025-10-14] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.2.1.
import copy import copy
import gc import gc
import logging import logging
import math import math
from typing import Optional, Tuple from typing import Callable, Optional, Tuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -18,6 +18,7 @@ import cutlass ...@@ -18,6 +18,7 @@ import cutlass
import cutlass.cute as cute import cutlass.cute as cute
import torch import torch
from cutlass.cute.runtime import from_dlpack from cutlass.cute.runtime import from_dlpack
from flash_attn.cute import utils
from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90
from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
...@@ -26,22 +27,6 @@ def maybe_contiguous(x): ...@@ -26,22 +27,6 @@ def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def _reason_recompile(compile_key, jit_func):
compile_cache = jit_func.compile_cache
compile_key_map = jit_func.compile_key_map
if not compile_cache:
return "not compiled yet"
for k, v in compile_cache.items():
if k == compile_key:
continue
if len(k) != len(compile_key):
continue
for i in range(len(k)):
if k[i] != compile_key[i]:
return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]} "
return "unknown reason"
torch2cute_dtype_map = { torch2cute_dtype_map = {
torch.float16: cutlass.Float16, torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16, torch.bfloat16: cutlass.BFloat16,
...@@ -72,7 +57,11 @@ def _flash_attn_fwd( ...@@ -72,7 +57,11 @@ def _flash_attn_fwd(
num_threads: int = 384, num_threads: int = 384,
pack_gqa: Optional[bool] = None, pack_gqa: Optional[bool] = None,
_compute_capability: Optional[int] = None, _compute_capability: Optional[int] = None,
return_softmax_lse: Optional[bool] = False, score_mod: Callable | None = None,
return_lse: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
buffers: Optional[list[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(t) for t in (q, k, v)] q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
num_head, head_dim = q.shape[-2:] num_head, head_dim = q.shape[-2:]
...@@ -169,6 +158,14 @@ def _flash_attn_fwd( ...@@ -169,6 +158,14 @@ def _flash_attn_fwd(
q_batch_seqlen_shape = ( q_batch_seqlen_shape = (
(batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
) )
lse_shape = (
(batch_size, num_head, seqlen_q)
if cu_seqlens_q is None
else (num_head, total_q)
)
requires_grad = q.requires_grad or k.requires_grad or v.requires_grad
if out is None:
out = torch.empty( out = torch.empty(
*q_batch_seqlen_shape, *q_batch_seqlen_shape,
num_head, num_head,
...@@ -176,16 +173,36 @@ def _flash_attn_fwd( ...@@ -176,16 +173,36 @@ def _flash_attn_fwd(
dtype=out_torch_dtype, dtype=out_torch_dtype,
device=device, device=device,
) )
lse_shape = ( else:
(batch_size, num_head, seqlen_q) expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v)
if cu_seqlens_q is None assert (
else (num_head, total_q) out.shape == expected_out_shape
) ), f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}"
assert (
out.dtype == out_torch_dtype
), f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}"
assert (
out.device == device
), f"out tensor device {out.device} does not match input device {device}"
assert out.is_cuda, "out tensor must be on CUDA device"
if lse is None:
lse = ( lse = (
torch.empty(lse_shape, dtype=torch.float32, device=device) torch.empty(lse_shape, dtype=torch.float32, device=device)
if return_softmax_lse if requires_grad or return_lse
else None else None
) )
elif lse is not None:
assert (
lse.shape == lse_shape
), f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}"
assert (
lse.dtype == torch.float32
), f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32"
assert (
lse.device == device
), f"lse tensor device {lse.device} does not match input device {device}"
assert lse.is_cuda, "lse tensor must be on CUDA device"
dtype = torch2cute_dtype_map[q.dtype] dtype = torch2cute_dtype_map[q.dtype]
q_tensor, k_tensor, v_tensor, o_tensor = [ q_tensor, k_tensor, v_tensor, o_tensor = [
...@@ -242,6 +259,7 @@ def _flash_attn_fwd( ...@@ -242,6 +259,7 @@ def _flash_attn_fwd(
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
if compute_capability == 9: # TODO: tune block size according to hdim if compute_capability == 9: # TODO: tune block size according to hdim
# Perf heuristic from upstream: hdim=128, noncausal, non-local benefits from larger n_block
if head_dim == head_dim_v == 128 and not causal and not local: if head_dim == head_dim_v == 128 and not causal and not local:
n_block_size = 192 n_block_size = 192
if compute_capability == 10: if compute_capability == 10:
...@@ -253,13 +271,34 @@ def _flash_attn_fwd( ...@@ -253,13 +271,34 @@ def _flash_attn_fwd(
): ):
pack_gqa = False pack_gqa = False
if softcap is not None:
assert score_mod is None, "softcap and score_mod cannot be used together"
score_mod = utils.create_softcap_scoremod(softcap)
if score_mod is not None:
is_varlen = (
cu_seqlens_q is not None
or cu_seqlens_k is not None
or seqused_q is not None
or seqused_k is not None
)
if is_varlen:
raise NotImplementedError(
"score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR."
)
cute_buffers = None
if buffers is not None:
cute_buffers = [from_dlpack(buf) for buf in buffers]
compile_key = ( compile_key = (
dtype, dtype,
head_dim, head_dim,
head_dim_v, head_dim_v,
qhead_per_kvhead, qhead_per_kvhead,
causal, causal,
softcap is not None, utils.hash_callable(score_mod) if score_mod is not None else None,
buffers is not None,
lse is None, lse is None,
cu_seqlens_q is None, cu_seqlens_q is None,
cu_seqlens_k is None, cu_seqlens_k is None,
...@@ -276,9 +315,6 @@ def _flash_attn_fwd( ...@@ -276,9 +315,6 @@ def _flash_attn_fwd(
compute_capability, compute_capability,
) )
if compile_key not in _flash_attn_fwd.compile_cache: if compile_key not in _flash_attn_fwd.compile_cache:
logger.info(
f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}"
)
if compute_capability == 9: if compute_capability == 9:
assert page_table is None, "paged KV not supported on SM 9.0" assert page_table is None, "paged KV not supported on SM 9.0"
# fa_fwd = FlashAttentionForwardSm80( # fa_fwd = FlashAttentionForwardSm80(
...@@ -290,12 +326,14 @@ def _flash_attn_fwd( ...@@ -290,12 +326,14 @@ def _flash_attn_fwd(
is_causal=causal, is_causal=causal,
is_local=local, is_local=local,
pack_gqa=pack_gqa, pack_gqa=pack_gqa,
m_block_size=m_block_size, tile_m=m_block_size,
n_block_size=n_block_size, tile_n=n_block_size,
# num_stages=1, # num_stages=1,
num_stages=2, num_stages=2,
num_threads=num_threads, num_threads=num_threads,
Q_in_regs=False, Q_in_regs=False,
score_mod=score_mod,
has_buffers=buffers is not None,
) )
elif compute_capability == 10: elif compute_capability == 10:
assert page_size in [ assert page_size in [
...@@ -313,12 +351,15 @@ def _flash_attn_fwd( ...@@ -313,12 +351,15 @@ def _flash_attn_fwd(
and not local and not local
and cu_seqlens_q is None and cu_seqlens_q is None
and seqused_q is None, and seqused_q is None,
score_mod=score_mod,
has_buffers=buffers is not None,
) )
else: else:
raise ValueError( raise ValueError(
f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x" f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x"
) )
# TODO: check @can_implement # TODO: check @can_implement
# TODO caching for buffers; cute_buffers
_flash_attn_fwd.compile_cache[compile_key] = cute.compile( _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
fa_fwd, fa_fwd,
q_tensor, q_tensor,
...@@ -333,10 +374,10 @@ def _flash_attn_fwd( ...@@ -333,10 +374,10 @@ def _flash_attn_fwd(
seqused_q_tensor, seqused_q_tensor,
seqused_k_tensor, seqused_k_tensor,
page_table_tensor, page_table_tensor,
softcap,
window_size_left, window_size_left,
window_size_right, window_size_right,
learnable_sink_tensor, learnable_sink_tensor,
cute_buffers,
) )
_flash_attn_fwd.compile_cache[compile_key]( _flash_attn_fwd.compile_cache[compile_key](
q_tensor, q_tensor,
...@@ -351,46 +392,29 @@ def _flash_attn_fwd( ...@@ -351,46 +392,29 @@ def _flash_attn_fwd(
seqused_q_tensor, seqused_q_tensor,
seqused_k_tensor, seqused_k_tensor,
page_table_tensor, page_table_tensor,
softcap,
window_size_left, window_size_left,
window_size_right, window_size_right,
learnable_sink_tensor, learnable_sink_tensor,
cute_buffers,
) )
return out, lse return out, lse
_flash_attn_fwd.compile_cache = {} _flash_attn_fwd.compile_cache = {}
_flash_attn_fwd.compile_key_map = [
"dtype",
"head_dim",
"head_dim_v",
"qhead_per_kvhead",
"causal",
"softcap is not None",
"lse is None",
"cu_seqlens_q is None",
"cu_seqlens_k is None",
"seqused_q is None",
"seqused_k is None",
"page_table is not None",
"window_size_left is not None",
"window_size_right is not None",
"learnable_sink is not None",
"m_block_size",
"n_block_size",
"num_threads",
"pack_gqa",
"compute_capability",
]
def warmup_flash_attn(f): def warmup_flash_attn(f):
""" """
Decorator for flash_attn_varlen_func: Decorator for flash_attn_varlen_func:
- On the first call, run several warmup passes with different flag combinations - On first call, run several warmup passes with different flag combinations:
- Warmups are executed sequentially to minimize peak GPU memory usage * return_softmax_lse in {False, True}
- Does not modify user-provided tensors (clones data) * global noncausal (window_size=(None,None))
- Easy to extend with more compile-key dimensions * causal (window_size=(None,0))
* local sliding window (window_size=(64,64))
* optionally pack_gqa=True if qheads > kvheads and allowed
- No score_mod / softcap (not supported for varlen yet)
- Executes sequentially to minimize peak GPU mem
- Does not modify user tensors (clones)
""" """
done = False done = False
...@@ -399,30 +423,78 @@ def warmup_flash_attn(f): ...@@ -399,30 +423,78 @@ def warmup_flash_attn(f):
def maybe_clone(x): def maybe_clone(x):
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
return x.clone() return x.detach().clone() # detach to avoid autograd edges
return copy.deepcopy(x) return copy.deepcopy(x)
return tuple(maybe_clone(a) for a in args), { return tuple(maybe_clone(a) for a in args), {
k: maybe_clone(v) for k, v in kwargs.items() k: maybe_clone(v) for k, v in kwargs.items()
} }
def _infer_heads(args, kwargs):
"""Infer q and kv head counts from arguments."""
# Expect signature: (q, k, v, cu_seqlens_q, cu_seqlens_k, ...)
q = args[0] if len(args) > 0 else kwargs.get("q")
k = args[1] if len(args) > 1 else kwargs.get("k")
try:
qh = int(q.shape[-2])
kvh = int(k.shape[-2])
return qh, kvh
except Exception:
return None, None
def _run_warmups(args, kwargs): def _run_warmups(args, kwargs):
"""Run warmup calls sequentially and release memory after each.""" """Run warmup calls sequentially and release memory after each."""
base_args, base_kwargs = _clone_args(args, kwargs) base_args, base_kwargs = _clone_args(args, kwargs)
# Warmup combinations for return_softmax_lse and causal qh, kvh = _infer_heads(base_args, base_kwargs)
combos = [ can_pack_gqa = (
dict(return_softmax_lse=False, causal=False), qh is not None and kvh is not None and qh % kvh == 0 and qh // kvh > 1
dict(return_softmax_lse=False, causal=True), )
dict(return_softmax_lse=True, causal=False), has_page_table = (
dict(return_softmax_lse=True, causal=True), "page_table" in base_kwargs and base_kwargs["page_table"] is not None
)
# Window presets covering global, causal, and local
window_presets = [
(None, None), # global noncausal
(None, 0), # causal
(64, 64), # local sliding window
] ]
lse_flags = [False, True]
# Base combo list
combos = []
for ws in window_presets:
for return_lse_flag in lse_flags:
combos.append(dict(window_size=ws, return_softmax_lse=return_lse_flag))
# Optionally add a pack_gqa=True variant (FA4 may disable it internally for some varlen shapes/SMs)
if can_pack_gqa:
for ws in window_presets:
combos.append(
dict(window_size=ws, return_softmax_lse=False, pack_gqa=True)
)
# If page_table is present, warm one combo with it (page_table in compile key for SM100)
if has_page_table:
combos.append(dict(window_size=(None, None), return_softmax_lse=False))
# Run sequentially
for combo in combos: for combo in combos:
wa, wk = _clone_args(base_args, base_kwargs) wa, wk = _clone_args(base_args, base_kwargs)
# Keep user-provided softcap/score_mod OUT (varlen+score_mod unsupported)
wk.pop("score_mod", None)
if "softcap" in wk and wk["softcap"]:
wk["softcap"] = 0.0
# Apply combo
wk.update(combo) wk.update(combo)
with torch.cuda.stream(torch.cuda.current_stream()): with torch.cuda.stream(torch.cuda.current_stream()):
try:
f(*wa, **wk) f(*wa, **wk)
except Exception as e:
# Some combos can be invalid for specific head dims / arch. Ignore and continue.
logger.debug("Warmup combo skipped: %s", e)
del wa, wk del wa, wk
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
...@@ -430,7 +502,9 @@ def warmup_flash_attn(f): ...@@ -430,7 +502,9 @@ def warmup_flash_attn(f):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
nonlocal done nonlocal done
if not done: if not done:
logger.info("Running flash_attn_varlen_func warmup passes...") logger.info(
"Running FA4 warmup (global/causal/local, LSE on/off, optional GQA pack)..."
)
_run_warmups(args, kwargs) _run_warmups(args, kwargs)
done = True done = True
return f(*args, **kwargs) return f(*args, **kwargs)
...@@ -472,7 +546,7 @@ def flash_attn_varlen_func( ...@@ -472,7 +546,7 @@ def flash_attn_varlen_func(
learnable_sink=learnable_sink, learnable_sink=learnable_sink,
softcap=softcap, softcap=softcap,
pack_gqa=pack_gqa, pack_gqa=pack_gqa,
return_softmax_lse=return_softmax_lse, return_lse=return_softmax_lse,
) )
return (out, lse) if return_softmax_lse else out return (out, lse) if return_softmax_lse else out
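In the interface changes above, the dedicated softcap path of the compiled kernel is replaced by a generic score_mod callback: when softcap is set, utils.create_softcap_scoremod(softcap) builds the callback and utils.hash_callable(score_mod) becomes part of the compile key. The real callback operates on Cute-DSL values inside the kernel; the plain-PyTorch sketch below (illustrative names, not the kernel API) only shows the math a softcap score_mod applies to pre-softmax scores:

import torch

def softcap_score_mod(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Soft-capping squashes raw attention scores into (-cap, cap) before softmax.
    return cap * torch.tanh(scores / cap)

def attn_reference(q, k, v, cap=15.0):
    # q: (s_q, h, d), k/v: (s_k, h, d); single sequence, no masking, for illustration.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("qhd,khd->hqk", q, k) * scale
    scores = softcap_score_mod(scores, cap)
    return torch.einsum("hqk,khd->qhd", scores.softmax(-1), v)

q, k, v = torch.randn(4, 2, 8), torch.randn(6, 2, 8), torch.randn(6, 2, 8)
print(attn_reference(q, k, v).shape)  # torch.Size([4, 2, 8])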
...@@ -45,7 +45,7 @@ def flash_attn_with_kvcache( ...@@ -45,7 +45,7 @@ def flash_attn_with_kvcache(
qv=None, qv=None,
rotary_cos=None, rotary_cos=None,
rotary_sin=None, rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None, cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None,
...@@ -59,6 +59,7 @@ def flash_attn_with_kvcache( ...@@ -59,6 +59,7 @@ def flash_attn_with_kvcache(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
attention_chunk: Optional[int] = None,
softcap=0.0, # 0.0 means deactivated softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True, rotary_interleaved=True,
scheduler_metadata=None, scheduler_metadata=None,
...@@ -137,6 +138,7 @@ def flash_attn_with_kvcache( ...@@ -137,6 +138,7 @@ def flash_attn_with_kvcache(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
attention_chunk: Optional[int]. If set to a non-zero value, restricts attention to fixed-size chunks of this length (chunked attention); None or 0 disables it.
softcap: float. Anything > 0 activates softcapping attention. softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
...@@ -216,6 +218,7 @@ def flash_attn_with_kvcache( ...@@ -216,6 +218,7 @@ def flash_attn_with_kvcache(
] ]
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens) rotary_seqlens = maybe_contiguous(rotary_seqlens)
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q, q,
...@@ -245,6 +248,7 @@ def flash_attn_with_kvcache( ...@@ -245,6 +248,7 @@ def flash_attn_with_kvcache(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
attention_chunk,
softcap, softcap,
rotary_interleaved, rotary_interleaved,
scheduler_metadata, scheduler_metadata,
...@@ -263,10 +267,11 @@ def flash_attn_varlen_func( ...@@ -263,10 +267,11 @@ def flash_attn_varlen_func(
v, v,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
max_seqlen_q, max_seqlen_q=None,
max_seqlen_k, max_seqlen_k=None,
seqused_q=None, seqused_q=None,
seqused_k=None, seqused_k=None,
page_table=None,
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
qv=None, qv=None,
...@@ -274,6 +279,7 @@ def flash_attn_varlen_func( ...@@ -274,6 +279,7 @@ def flash_attn_varlen_func(
k_descale=None, k_descale=None,
v_descale=None, v_descale=None,
window_size=(-1, -1), window_size=(-1, -1),
attention_chunk=0,
softcap=0.0, softcap=0.0,
num_splits=1, num_splits=1,
pack_gqa=None, pack_gqa=None,
...@@ -293,25 +299,18 @@ def flash_attn_varlen_func( ...@@ -293,25 +299,18 @@ def flash_attn_varlen_func(
q, q,
k, k,
v, v,
cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k=cu_seqlens_k,
# max_seqlen_q,
# max_seqlen_k,
seqused_q=seqused_q, seqused_q=seqused_q,
seqused_k=seqused_k, seqused_k=seqused_k,
page_table=page_table,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=causal, causal=causal,
# qv=qv,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
window_size=window_size, window_size=window_size,
softcap=softcap, softcap=softcap,
# num_splits=num_splits,
pack_gqa=pack_gqa, pack_gqa=pack_gqa,
# sm_margin=sm_margin,
return_softmax_lse=return_softmax_lse,
learnable_sink=sinks, learnable_sink=sinks,
return_softmax_lse=return_softmax_lse,
) )
if not is_fa3_supported(): if not is_fa3_supported():
...@@ -319,10 +318,15 @@ def flash_attn_varlen_func( ...@@ -319,10 +318,15 @@ def flash_attn_varlen_func(
"flash_attn at sgl-kernel is only supported on sm90 and above" "flash_attn at sgl-kernel is only supported on sm90 and above"
) )
# FA3 requires max_seqlen_q and max_seqlen_k
if max_seqlen_q is None or max_seqlen_k is None:
raise ValueError("max_seqlen_q and max_seqlen_k are required for FA3")
if softmax_scale is None: if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5 -0.5
) )
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q, q,
...@@ -352,6 +356,7 @@ def flash_attn_varlen_func( ...@@ -352,6 +356,7 @@ def flash_attn_varlen_func(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
attention_chunk,
softcap, softcap,
is_rotary_interleaved=False, is_rotary_interleaved=False,
scheduler_metadata=None, scheduler_metadata=None,
......
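flash_attn_varlen_func now takes max_seqlen_q/max_seqlen_k as optional keyword arguments and raises for the FA3 path when they are missing, while cu_seqlens_q/cu_seqlens_k keep the b+1 layout from the schema. A small sketch of deriving all four from per-sequence lengths (the lengths here are made-up):

import torch

# Made-up per-sequence lengths for a batch of 3 variable-length requests.
seqlens_q = torch.tensor([5, 3, 7])
seqlens_k = torch.tensor([9, 3, 12])

def to_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    # (b,) lengths -> (b+1,) exclusive prefix sums, the "b+1" layout in the fwd schema.
    cu = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
    cu[1:] = seqlens.cumsum(0)
    return cu

cu_seqlens_q = to_cu_seqlens(seqlens_q)   # tensor([ 0,  5,  8, 15], dtype=torch.int32)
cu_seqlens_k = to_cu_seqlens(seqlens_k)   # tensor([ 0,  9, 12, 24], dtype=torch.int32)
max_seqlen_q = int(seqlens_q.max())       # 7  -- required by the FA3 path above
max_seqlen_k = int(seqlens_k.max())       # 12

# q would be packed as (total_q, h, d) = (15, h, d); k/v as (total_k, h_k, d) = (24, h_k, d),
# with rows laid out in cu_seqlens order.
print(cu_seqlens_q.tolist(), cu_seqlens_k.tolist(), max_seqlen_q, max_seqlen_k)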
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/b31ae1e4cd22cf5f820a2995b74b7cd3bd54355a/tests/cute/test_flash_attn.py # Adapted from https://github.com/Dao-AILab/flash-attention/blob/8ecf128f683266735ba68e3c106ff67a2611886e/tests/cute/test_flash_attn.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
...@@ -10,12 +10,25 @@ import pytest ...@@ -10,12 +10,25 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
try:
from flash_attn.layers.rotary import apply_rotary_emb
except ImportError:
apply_rotary_emb = None
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from utils import is_hopper from sgl_kernel.testing.rotary_embedding import _apply_rotary_emb as apply_rotary_emb
# from utils import is_hopper # Not used in this test
# Force sgl_kernel.flash_attn wrappers to use FA4 (Cute-DSL) implementations.
# The wrappers accept a superset of args; for FA4, extra args are ignored.
flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4) flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4)
flash_attn_with_kvcache = partial(flash_attn_with_kvcache, ver=4) flash_attn_with_kvcache = partial(flash_attn_with_kvcache, ver=4)
# Skip on devices below compute capability 10.0 (e.g., Hopper); these FA4 tests require Blackwell.
skip_condition = torch.cuda.get_device_capability() < (10, 0)
def unpad_input(hidden_states, attention_mask, unused_mask=None): def unpad_input(hidden_states, attention_mask, unused_mask=None):
""" """
...@@ -88,6 +101,11 @@ def generate_random_padding_mask( ...@@ -88,6 +101,11 @@ def generate_random_padding_mask(
lengths = torch.randint( lengths = torch.randint(
max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device
) )
else:
# This should never happen due to the assertion above, but for linter
lengths = torch.full(
(batch_size, 1), max_seqlen, device=device, dtype=torch.int32
)
if zero_lengths: if zero_lengths:
# Generate zero-lengths every 5 batches and the last batch. # Generate zero-lengths every 5 batches and the last batch.
...@@ -482,8 +500,7 @@ def attention_ref( ...@@ -482,8 +500,7 @@ def attention_ref(
@pytest.mark.skipif( @pytest.mark.skipif(
is_hopper(), skip_condition, reason="FA4 requires compute capability 10.0 or above."
reason="skip on hopper",
) )
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
...@@ -497,8 +514,8 @@ def attention_ref( ...@@ -497,8 +514,8 @@ def attention_ref(
@pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0]) # @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("add_unused_qkv", [False, True]) # @pytest.mark.parametrize("add_unused_qkv", [False, True])
...@@ -522,11 +539,11 @@ def attention_ref( ...@@ -522,11 +539,11 @@ def attention_ref(
(64, 128), (64, 128),
(128, 128), (128, 128),
(256, 256), (256, 256),
(113, 203), # (113, 203),
(128, 217), # (128, 217),
(113, 211), # (113, 211),
(108, 256), # (108, 256),
(256, 512), # (256, 512),
(307, 256), (307, 256),
(640, 128), (640, 128),
(512, 256), (512, 256),
...@@ -658,25 +675,7 @@ def test_flash_attn_varlen_output( ...@@ -658,25 +675,7 @@ def test_flash_attn_varlen_output(
if causal or local: if causal or local:
key_padding_mask = query_padding_mask key_padding_mask = query_padding_mask
( result = generate_qkv(
q_unpad,
k_unpad,
v_unpad,
qv_unpad,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
qv,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(
q, q,
k, k,
v, v,
...@@ -687,6 +686,25 @@ def test_flash_attn_varlen_output( ...@@ -687,6 +686,25 @@ def test_flash_attn_varlen_output(
query_unused_mask=query_unused_mask, query_unused_mask=query_unused_mask,
key_unused_mask=key_unused_mask, key_unused_mask=key_unused_mask,
) )
(
q_unpad, # 0
k_unpad, # 1
v_unpad, # 2
qv_unpad, # 3
cu_seqlens_q, # 4
cu_seqlens_k, # 5
seqused_q, # 6
seqused_k, # 7
max_seqlen_q, # 8
max_seqlen_k, # 9
q, # 10
k, # 11
v, # 12
qv, # 13
output_pad_fn, # 14
dq_pad_fn, # 15
dk_pad_fn, # 16
) = result
q_unpad, k_unpad, v_unpad = [ q_unpad, k_unpad, v_unpad = [
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
] ]
...@@ -746,20 +764,16 @@ def test_flash_attn_varlen_output( ...@@ -746,20 +764,16 @@ def test_flash_attn_varlen_output(
v_unpad, v_unpad,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k, cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=None, # max_seqlen_q and max_seqlen_k not needed for FA4
max_seqlen_k=None, seqused_q=seqused_q,
# seqused_q=seqused_q, seqused_k=seqused_k,
# seqused_k=seqused_k,
causal=causal, causal=causal,
# qv=qv_unpad,
# q_descale=q_descale,
# k_descale=k_descale, v_descale=v_descale,
window_size=window_size, window_size=window_size,
# attention_chunk=attention_chunk,
sinks=learnable_sink,
softcap=softcap, softcap=softcap,
sinks=learnable_sink, # FA4 uses learnable_sink, not sinks
pack_gqa=pack_gqa, pack_gqa=pack_gqa,
return_softmax_lse=True, return_softmax_lse=True,
ver=4, # Use FA4
) )
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
if query_unused_mask is not None: if query_unused_mask is not None:
...@@ -875,8 +889,7 @@ def test_flash_attn_varlen_output( ...@@ -875,8 +889,7 @@ def test_flash_attn_varlen_output(
@pytest.mark.skipif( @pytest.mark.skipif(
is_hopper(), skip_condition, reason="FA4 requires compute capability 10.0 or above."
reason="skip on hopper",
) )
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
...@@ -887,8 +900,8 @@ def test_flash_attn_varlen_output( ...@@ -887,8 +900,8 @@ def test_flash_attn_varlen_output(
# @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("new_kv", [False, True]) # @pytest.mark.parametrize("new_kv", [False, True])
@pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
...@@ -900,8 +913,8 @@ def test_flash_attn_varlen_output( ...@@ -900,8 +913,8 @@ def test_flash_attn_varlen_output(
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) # @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
@pytest.mark.parametrize("page_size", [None, 128]) # @pytest.mark.parametrize("page_size", [None, 128])
# @pytest.mark.parametrize("page_size", [128]) @pytest.mark.parametrize("page_size", [128])
# @pytest.mark.parametrize("has_leftpad", [False, True]) # @pytest.mark.parametrize("has_leftpad", [False, True])
@pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_leftpad", [False])
# @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False, True])
...@@ -1085,6 +1098,7 @@ def test_flash_attn_kvcache( ...@@ -1085,6 +1098,7 @@ def test_flash_attn_kvcache(
.to(dtype_ref) .to(dtype_ref)
) )
page_table = None page_table = None
num_blocks = None
else: else:
( (
k_cache, k_cache,
...@@ -1301,31 +1315,24 @@ def test_flash_attn_kvcache( ...@@ -1301,31 +1315,24 @@ def test_flash_attn_kvcache(
else: else:
k_cache_paged.copy_(k_cache_saved) k_cache_paged.copy_(k_cache_saved)
v_cache_paged.copy_(v_cache_saved) v_cache_paged.copy_(v_cache_saved)
# out, lse, *rest = flash_attn_with_kvcache( # For FA4, use flash_attn_varlen_func directly instead of flash_attn_with_kvcache
out, lse, *rest = flash_attn_with_kvcache( # This matches the pattern from the original FA4 test
out, lse = flash_attn_varlen_func(
q if not varlen_q else q_unpad, q if not varlen_q else q_unpad,
k_cache if page_size is None else k_cache_paged, k_cache if page_size is None else k_cache_paged,
v_cache if page_size is None else v_cache_paged, v_cache if page_size is None else v_cache_paged,
# k if not new_kv or not varlen_q else k_unpad,
# v if not new_kv or not varlen_q else v_unpad,
# qv=qv if not varlen_q else qv_unpad,
# rotary_cos=cos,
# rotary_sin=sin,
cache_seqlens=cache_seqlens,
# cache_batch_idx=cache_batch_idx,
# cache_leftpad=cache_leftpad,
page_table=page_table,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
# cu_seqlens_k_new=cu_seqlens_k_new, cu_seqlens_k=None, # FA4 doesn't use cu_seqlens_k for KV cache
# rotary_seqlens=rotary_seqlens, # max_seqlen_q and max_seqlen_k not needed for FA4
seqused_k=cache_seqlens, # Use cache_seqlens as seqused_k
page_table=page_table,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
sinks=learnable_sink, sinks=learnable_sink, # FA4 uses learnable_sink, not sinks
# attention_chunk=attention_chunk, softcap=0.0,
# rotary_interleaved=rotary_interleaved, pack_gqa=None,
# scheduler_metadata=scheduler_metadata,
# num_splits=num_splits,
return_softmax_lse=True, return_softmax_lse=True,
ver=4, # Use FA4
) )
if varlen_q: if varlen_q:
out = output_pad_fn(out) out = output_pad_fn(out)
......
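The kv-cache test above drives FA4 through flash_attn_varlen_func with a page_table and seqused_k=cache_seqlens instead of flash_attn_with_kvcache. A toy illustration of that paged layout, using a tiny page_size for readability (FA4 itself is restricted to page_size 128, per the ServerArgs change at the top of this commit); the gather helper is hypothetical, not part of sgl-kernel:

import torch

# Toy paged-KV layout: page_size tokens per page, gathered back via a page table.
page_size, h_k, d = 4, 2, 8
num_pages = 6
k_cache_paged = torch.randn(num_pages, page_size, h_k, d)    # (num_pages, page_size, h_k, d)

# One sequence of 10 cached tokens spread over pages [2, 0, 5].
page_table = torch.tensor([[2, 0, 5]])                        # (b_k, max_num_pages_per_seq)
seqused_k = torch.tensor([10])                                # only the first 10 gathered tokens are valid

def gather_kv(cache, table, used):
    b = table.shape[0]
    flat = cache[table.view(-1)].view(b, -1, *cache.shape[2:])  # (b, pages * page_size, h_k, d)
    return [flat[i, : int(used[i])] for i in range(b)]

k_seq = gather_kv(k_cache_paged, page_table, seqused_k)[0]
print(k_seq.shape)  # torch.Size([10, 2, 8])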
...@@ -169,6 +169,7 @@ suites = { ...@@ -169,6 +169,7 @@ suites = {
TestFile("test_disaggregation_pp.py", 140), TestFile("test_disaggregation_pp.py", 140),
], ],
"per-commit-4-gpu-b200": [ "per-commit-4-gpu-b200": [
# TestFile("test_flash_attention_4.py"),
# TestFile("test_gpt_oss_4gpu.py", 600), # TestFile("test_gpt_oss_4gpu.py", 600),
# TestFile("test_deepseek_v3_fp4_4gpu.py", 3600), # TestFile("test_deepseek_v3_fp4_4gpu.py", 3600),
], ],
......
import unittest
from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
class TestFlashAttention4(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--trust-remote-code",
"--mem-fraction-static",
"0.8",
"--prefill-attention-backend",
"fa4",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=4,
data_path=None,
num_questions=100,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.65)
if __name__ == "__main__":
unittest.main()