Unverified Commit edefab0c authored by Lifu Huang, committed by GitHub

[2/2] Support MHA prefill with FlashAttention 4. (#10937)


Co-authored-by: Hieu Pham <hyhieu@gmail.com>
parent 97cd38e5
@@ -53,7 +53,7 @@ dependencies = [
     "scipy",
     "sentencepiece",
     "setproctitle",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "soundfile==0.13.1",
     "tiktoken",
     "timm==1.0.16",

@@ -65,7 +65,7 @@ tracing = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",

@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.14",
+            "0.3.15",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):
 @register_attention_backend("fa4")
 def create_flashattention_v4_backend(runner):
-    assert (
-        runner.use_mla_backend
-    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
     from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
     return FlashAttentionBackend(runner, fa_impl_ver=4)

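
For context, `register_attention_backend` is a name-keyed factory registry, and removing the MLA-only assert above means the `fa4` entry can now be built for regular multi-head-attention models as well. Below is a minimal sketch of how such a decorator-based registry typically works; the names and the placeholder return value are illustrative, not the exact sglang implementation.

from typing import Callable, Dict

# Illustrative registry; the real decorator lives in sglang's attention-backend
# factory module and returns actual backend instances.
ATTENTION_BACKENDS: Dict[str, Callable] = {}


def register_attention_backend(name: str):
    """Record a backend factory under a string key such as "fa3" or "fa4"."""

    def decorator(factory: Callable) -> Callable:
        ATTENTION_BACKENDS[name] = factory
        return factory

    return decorator


@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    # After this commit there is no MLA-only assert here, so any model runner
    # (MLA or plain multi-head attention) may request the FA4 backend.
    return ("FlashAttentionBackend", runner, {"fa_impl_ver": 4})  # placeholder


# Runtime lookup: the backend string resolved from ServerArgs picks the factory.
factory = ATTENTION_BACKENDS["fa4"]
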
@@ -754,7 +754,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
-            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id

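
The deleted assert was the last guard keeping FlashAttention 4 off the non-MLA (standard multi-head attention) prefill path; with it gone, `fa_impl_ver` may be 3 or 4 when the backend runs prefill over the paged KV cache. A hypothetical sketch of the version-gated dispatch this enables follows; the kernel labels are placeholders, not the real entry points.

def select_mha_prefill_path(fa_impl_ver: int) -> str:
    # Placeholder labels only; the real backend dispatches to sgl-kernel /
    # flash-attn prefill kernels based on self.fa_impl_ver.
    supported = {3: "fa3-varlen-prefill", 4: "fa4-varlen-prefill"}
    if fa_impl_ver not in supported:
        raise ValueError(f"Unsupported FlashAttention version: {fa_impl_ver}")
    return supported[fa_impl_ver]
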
@@ -1746,16 +1746,10 @@ class ModelRunner:
     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,

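
When the resolved prefill and decode backend strings differ (for example `fa4` for prefill and `fa3` for decode), `ModelRunner` wraps the two backends in `HybridAttnBackend`. The sketch below only illustrates that routing idea; the constructor signature and the `is_decode` flag are assumptions, and the real class in `sglang.srt.layers.attention.hybrid_attn_backend` has a richer interface.

class HybridAttnBackendSketch:
    """Illustrative stand-in: route prefill and decode batches to different backends."""

    def __init__(self, prefill_backend, decode_backend):
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def forward(self, forward_batch, *args, **kwargs):
        # `is_decode` is an assumed flag; the real code inspects
        # forward_batch.forward_mode instead.
        backend = self.decode_backend if forward_batch.is_decode else self.prefill_backend
        return backend.forward(forward_batch, *args, **kwargs)
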
@@ -464,6 +464,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3

+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
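
The new `get_attention_backends` helper resolves each phase independently: an explicit `prefill_attention_backend` or `decode_attention_backend` value wins, otherwise both fall back to the shared `attention_backend`. A tiny self-contained illustration of that fallback, using a stand-in object rather than a real `ServerArgs`:

from types import SimpleNamespace

# Stand-in for ServerArgs carrying only the three relevant fields.
args = SimpleNamespace(
    attention_backend="fa3",
    prefill_attention_backend="fa4",
    decode_attention_backend=None,
)

prefill = args.prefill_attention_backend or args.attention_backend
decode = args.decode_attention_backend or args.attention_backend
assert (prefill, decode) == ("fa4", "fa3")
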
@@ -748,20 +761,28 @@ class ServerArgs:
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
             logger.info(
                 f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
             )
+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
+                f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got the following backends\n"
+                f"- Prefill: {prefill_attn_backend}\n"
+                f"- Decode: {decode_attn_backend}\n"
+            )
             if is_sm100_supported():
                 if not self.enable_dp_attention:
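
Put together, these changes let GPT-OSS-style MHA models run prefill on FlashAttention 4 while keeping a different decode backend. A hedged launch sketch using the offline engine API is shown below; the keyword names mirror the `ServerArgs` fields in this diff, while the model path and the exact set of accepted backend strings depend on your GPU and sgl-kernel build.

import sglang as sgl

# Assumed kwargs: Engine forwards keyword arguments to ServerArgs, so the
# field names from this diff are used directly. Model path is illustrative.
llm = sgl.Engine(
    model_path="openai/gpt-oss-20b",
    prefill_attention_backend="fa4",  # MHA prefill on FlashAttention 4
    decode_attention_backend="fa3",   # decode stays on FlashAttention 3
)

print(llm.generate("Hello from FA4 prefill!", {"max_new_tokens": 16}))
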