"git@developer.sourcefind.cn:change/sglang.git" did not exist on "403566bcca66cd892804d0b379fb37cb213e5074"
Unverified commit edefab0c, authored by Lifu Huang and committed by GitHub

[2/2] Support MHA prefill with FlashAttention 4. (#10937)


Co-authored-by: Hieu Pham <hyhieu@gmail.com>
parent 97cd38e5
@@ -53,7 +53,7 @@ dependencies = [
     "scipy",
     "sentencepiece",
     "setproctitle",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "soundfile==0.13.1",
     "tiktoken",
     "timm==1.0.16",
...
@@ -65,7 +65,7 @@ tracing = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
...
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.14",
+            "0.3.15",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
...
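Note: the launcher enforces this minimum sgl-kernel version unless SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK is set. Below is a minimal sketch of such a launch-time gate, assuming the stdlib importlib.metadata and the packaging library; it is not sglang's actual assert_pkg_version / get_bool_env_var implementation.

```python
# Minimal sketch of a launch-time package-version gate like the one above.
# NOT sglang's actual assert_pkg_version implementation; shown for illustration.
import os
from importlib.metadata import version

from packaging.version import Version


def check_sgl_kernel(min_version: str = "0.3.15") -> None:
    # Same escape hatch as in the diff: set the env var to skip the check.
    if os.environ.get("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK", "").lower() in ("1", "true"):
        return
    installed = version("sgl-kernel")
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"sgl-kernel {installed} is older than {min_version}. "
            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`"
        )
```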
@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):
 @register_attention_backend("fa4")
 def create_flashattention_v4_backend(runner):
-    assert (
-        runner.use_mla_backend
-    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
     from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
     return FlashAttentionBackend(runner, fa_impl_ver=4)
...
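For readers unfamiliar with the decorator used above: it registers a factory under a string key ("fa4") that the runner later looks up. A minimal sketch of such a registry follows; the dict name and internals are assumptions, not sglang's actual attention_registry code.

```python
# Illustrative sketch of a string-keyed backend registry matching the decorator
# shape used above; the registry dict and its internals are assumptions.
from typing import Callable, Dict

ATTENTION_BACKENDS: Dict[str, Callable] = {}


def register_attention_backend(name: str):
    def decorator(factory: Callable) -> Callable:
        ATTENTION_BACKENDS[name] = factory
        return factory
    return decorator


@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    # With the MLA-only assertion removed, "fa4" can now be selected for
    # regular multi-head attention (MHA) models as well.
    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

    return FlashAttentionBackend(runner, fa_impl_ver=4)
```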
@@ -754,7 +754,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
-            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
...
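With the "Only FA3 support here" assertion gone, the non-MLA (MHA) prefill path now tolerates fa_impl_ver == 4. The sketch below only illustrates what dispatching on that field could look like; the method and kernel wrapper names are placeholders, not sglang's real code.

```python
# Placeholder sketch only: wrapper names below are hypothetical, shown to
# illustrate dispatching MHA prefill on fa_impl_ver now that both 3 and 4
# are accepted on this path.
def _mha_prefill(self, q, k, v, metadata):
    if self.fa_impl_ver == 4:
        return self._fa4_varlen_prefill(q, k, v, metadata)  # hypothetical wrapper
    return self._fa3_varlen_prefill(q, k, v, metadata)      # hypothetical wrapper
```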
@@ -1746,16 +1746,10 @@ class ModelRunner:
     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
...
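When the resolved prefill and decode backends differ (for example an "fa4" prefill with an "fa3" decode), the runner wraps them in a hybrid backend. A minimal routing sketch follows, assuming the wrapper simply forwards each call to the backend for that phase; sglang's real HybridAttnBackend likely does more (forward metadata, CUDA-graph state, etc.).

```python
# Minimal routing sketch of a prefill/decode hybrid wrapper; an assumption
# about the shape of HybridAttnBackend, not its actual implementation.
class HybridAttnBackendSketch:
    def __init__(self, prefill_backend, decode_backend):
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def forward_extend(self, *args, **kwargs):
        # Prefill / extend batches use the prefill backend (e.g. "fa4").
        return self.prefill_backend.forward_extend(*args, **kwargs)

    def forward_decode(self, *args, **kwargs):
        # Incremental decode batches use the decode backend (e.g. "fa3").
        return self.decode_backend.forward_decode(*args, **kwargs)
```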
@@ -464,6 +464,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
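Note that the new helper is declared with a `server_args` parameter rather than `self`, so any object exposing the three backend attributes satisfies it; each phase prefers its own flag and otherwise falls back to --attention-backend. A small usage sketch (SimpleNamespace stands in for a fully constructed ServerArgs):

```python
# Usage sketch of ServerArgs.get_attention_backends: per-phase settings win,
# otherwise both phases fall back to attention_backend.
from types import SimpleNamespace

from sglang.srt.server_args import ServerArgs

args = SimpleNamespace(
    attention_backend="fa3",          # global default
    prefill_attention_backend="fa4",  # per-phase override for prefill
    decode_attention_backend=None,    # no override: falls back to "fa3"
)
prefill, decode = ServerArgs.get_attention_backends(args)
assert (prefill, decode) == ("fa4", "fa3")
```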
@@ -748,20 +761,28 @@ class ServerArgs:
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
-            logger.info(
-                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
-            )
+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
+                f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got the following backends\n"
+                f"- Prefill: {prefill_attn_backend}\n"
+                f"- Decode: {decode_attn_backend}\n"
+            )
             if is_sm100_supported():
                 if not self.enable_dp_attention:
...
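The GptOss check above now validates the prefill and decode backends independently against a list that includes "fa4". A condensed restatement of that check, useful for reasoning about which combinations pass:

```python
# Condensed restatement of the GptOssForCausalLM backend check above,
# for illustration only; not a drop-in replacement for ServerArgs logic.
supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]


def validate(prefill_attn_backend: str, decode_attn_backend: str) -> None:
    assert (
        prefill_attn_backend in supported_backends
        and decode_attn_backend in supported_backends
    ), (
        f"GptOssForCausalLM requires one of {supported_backends} attention backend, "
        f"but got prefill={prefill_attn_backend}, decode={decode_attn_backend}"
    )


validate("fa4", "fa3")            # passes: FA4 prefill with FA3 decode
# validate("fa4", "flashinfer")   # would raise: decode backend not in the list
```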