Unverified Commit a2f7218a authored by cicirori's avatar cicirori Committed by GitHub
Browse files

support using fa4 on deepseek on blackwell (#9928)

parent 311de47b
......@@ -666,6 +666,13 @@ def _set_envs_and_config(server_args: ServerArgs):
if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
os.environ["TRTLLM_ENABLE_PDL"] = "1"
if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
# Default to warning level, to avoid too many logs
os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
# Need to set log to console, otherwise the log level won't take effect
os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
# Can also be passed as argument
os.environ["SGLANG_RUN_ID"] = (
f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
......
......@@ -305,6 +305,7 @@ class FlashAttentionBackend(AttentionBackend):
speculative_step_id=0,
topk=0,
speculative_num_steps=0,
fa_impl_ver=3,
):
super().__init__()
......@@ -338,6 +339,8 @@ class FlashAttentionBackend(AttentionBackend):
)
self.speculative_step_id = speculative_step_id
self.fa_impl_ver = fa_impl_ver
# Local attention settings
self.attention_chunk_size = (
model_runner.attention_chunk_size
......@@ -712,6 +715,8 @@ class FlashAttentionBackend(AttentionBackend):
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if self.fa_impl_ver != 3:
kwargs["ver"] = self.fa_impl_ver
if sinks is not None:
kwargs["sinks"] = sinks
......@@ -738,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
# Use Flash Attention for prefill
if not self.use_mla:
assert self.fa_impl_ver in [3], "Only FA3 support here"
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
......@@ -830,6 +836,7 @@ class FlashAttentionBackend(AttentionBackend):
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
**kwargs,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
......@@ -844,6 +851,7 @@ class FlashAttentionBackend(AttentionBackend):
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
if forward_batch.mha_return_lse:
output, lse, *rest = output
......@@ -851,6 +859,7 @@ class FlashAttentionBackend(AttentionBackend):
return output, lse
return output
else:
assert self.fa_impl_ver in [3], "Only FA3 support here"
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
......@@ -939,6 +948,7 @@ class FlashAttentionBackend(AttentionBackend):
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fa_impl_ver in [3], "Only FA3 support decoding"
if k is not None:
assert v is not None
if save_kv_cache:
......@@ -985,6 +995,8 @@ class FlashAttentionBackend(AttentionBackend):
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if self.fa_impl_ver != 3:
kwargs["ver"] = self.fa_impl_ver
if sinks is not None:
kwargs["sinks"] = sinks
......
......@@ -21,6 +21,7 @@ class HybridAttnBackend(AttentionBackend):
self.model_runner = model_runner
self.prefill_backend = prefill_backend
self.decode_backend = decode_backend
self.data_type = model_runner.kv_cache_dtype
def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
"""
......
......@@ -516,6 +516,7 @@ class ModelRunner:
"aiter",
"flashinfer",
"fa3",
"fa4",
"triton",
"flashmla",
"cutlass_mla",
......@@ -1800,6 +1801,15 @@ class ModelRunner:
)
return FlashAttentionBackend(self)
elif backend_str == "fa4":
assert (
self.use_mla_backend
), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
return FlashAttentionBackend(self, fa_impl_ver=4)
elif backend_str == "cutlass_mla":
from sglang.srt.layers.attention.cutlass_mla_backend import (
CutlassMLABackend,
......
......@@ -1124,6 +1124,9 @@ class DeepseekV2AttentionMLA(nn.Module):
return AttnForwardMethod.MHA_CHUNKED_KV
else:
return _dispatch_mla_subtype()
elif attention_backend == "fa4":
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
return AttnForwardMethod.MHA_CHUNKED_KV
elif attention_backend == "trtllm_mla":
original_mode = getattr(forward_batch, "_original_forward_mode", None)
if (
......
......@@ -96,6 +96,7 @@ ATTENTION_BACKEND_CHOICES = [
# NVIDIA specific
"cutlass_mla",
"fa3",
"fa4",
"flashinfer",
"flashmla",
"trtllm_mla",
......
......@@ -4,9 +4,15 @@
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
import copy
import gc
import logging
import math
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
......@@ -20,6 +26,22 @@ def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def _reason_recompile(compile_key, jit_func):
compile_cache = jit_func.compile_cache
compile_key_map = jit_func.compile_key_map
if not compile_cache:
return "not compiled yet"
for k, v in compile_cache.items():
if k == compile_key:
continue
if len(k) != len(compile_key):
continue
for i in range(len(k)):
if k[i] != compile_key[i]:
return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]} "
return "unknown reason"
torch2cute_dtype_map = {
torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16,
......@@ -254,6 +276,9 @@ def _flash_attn_fwd(
compute_capability,
)
if compile_key not in _flash_attn_fwd.compile_cache:
logger.info(
f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}"
)
if compute_capability == 9:
assert page_table is None, "paged KV not supported on SM 9.0"
# fa_fwd = FlashAttentionForwardSm80(
......@@ -335,8 +360,85 @@ def _flash_attn_fwd(
_flash_attn_fwd.compile_cache = {}
_flash_attn_fwd.compile_key_map = [
"dtype",
"head_dim",
"head_dim_v",
"qhead_per_kvhead",
"causal",
"softcap is not None",
"lse is None",
"cu_seqlens_q is None",
"cu_seqlens_k is None",
"seqused_q is None",
"seqused_k is None",
"page_table is not None",
"window_size_left is not None",
"window_size_right is not None",
"learnable_sink is not None",
"m_block_size",
"n_block_size",
"num_threads",
"pack_gqa",
"compute_capability",
]
def warmup_flash_attn(f):
"""
Decorator for flash_attn_varlen_func:
- On the first call, run several warmup passes with different flag combinations
- Warmups are executed sequentially to minimize peak GPU memory usage
- Does not modify user-provided tensors (clones data)
- Easy to extend with more compile-key dimensions
"""
done = False
def _clone_args(args, kwargs):
"""Clone tensor arguments to avoid sharing storage; deepcopy for others."""
def maybe_clone(x):
if isinstance(x, torch.Tensor):
return x.clone()
return copy.deepcopy(x)
return tuple(maybe_clone(a) for a in args), {
k: maybe_clone(v) for k, v in kwargs.items()
}
def _run_warmups(args, kwargs):
"""Run warmup calls sequentially and release memory after each."""
base_args, base_kwargs = _clone_args(args, kwargs)
# Warmup combinations for return_softmax_lse and causal
combos = [
dict(return_softmax_lse=False, causal=False),
dict(return_softmax_lse=False, causal=True),
dict(return_softmax_lse=True, causal=False),
dict(return_softmax_lse=True, causal=True),
]
for combo in combos:
wa, wk = _clone_args(base_args, base_kwargs)
wk.update(combo)
with torch.cuda.stream(torch.cuda.current_stream()):
f(*wa, **wk)
del wa, wk
torch.cuda.empty_cache()
gc.collect()
def wrapper(*args, **kwargs):
nonlocal done
if not done:
logger.info("Running flash_attn_varlen_func warmup passes...")
_run_warmups(args, kwargs)
done = True
return f(*args, **kwargs)
return wrapper
@warmup_flash_attn
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment