Unverified Commit 2bc61dd1 authored by li-kesen, committed by GitHub

Remove hybrid_linear_attn attention backend and refactor attention registry (#10816)


Co-authored-by: Yi Zhang <1109276519@qq.com>
parent 6535fda1
+import logging
+
+logger = logging.getLogger(__name__)
+
 ATTENTION_BACKENDS = {}
...
@@ -158,35 +162,37 @@ def create_dual_chunk_flash_attn_backend(runner):
     return DualChunkFlashAttentionBackend(runner)
 
 
-@register_attention_backend("hybrid_linear_attn")
-def create_hybrid_linear_attn_backend(runner):
-    assert (
-        runner.is_hybrid_gdn
-    ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
-    from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
-        HybridLinearAttnBackend,
-        MambaAttnBackend,
-    )
-    from sglang.srt.utils import is_blackwell, is_npu
-
-    if is_npu():
-        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-        full_attn_backend = AscendAttnBackend(runner)
-    elif is_blackwell():
-        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-        full_attn_backend = TritonAttnBackend(runner)
-    else:
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionBackend,
-        )
-
-        full_attn_backend = FlashAttentionBackend(runner)
-
-    linear_attn_backend = MambaAttnBackend(runner)
-    full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
-    return HybridLinearAttnBackend(
-        full_attn_backend, linear_attn_backend, full_attn_layers
-    )
+def attn_backend_wrapper(runner, full_attn_backend):
+    """
+    Wrapper for special models like hybrid GDN, so we don't
+    need to change the code of the original attention backend.
+    """
+    assert not (
+        runner.is_hybrid_gdn and runner.use_mla_backend
+    ), "hybrid_gdn can only be used with non-MLA models."
+
+    # wrap for hybrid GDN models
+    if runner.is_hybrid_gdn:
+        from sglang.srt.utils import is_blackwell, is_npu
+
+        if is_blackwell():
+            assert (
+                runner.server_args.attention_backend == "triton"
+            ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
+        if is_npu():
+            assert (
+                runner.server_args.attention_backend == "ascend"
+            ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+        logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+            HybridLinearAttnBackend,
+            MambaAttnBackend,
+        )
+
+        linear_attn_backend = MambaAttnBackend(runner)
+        full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+        return HybridLinearAttnBackend(
+            full_attn_backend, linear_attn_backend, full_attn_layers
+        )
+    return full_attn_backend
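
For orientation, here is a minimal, self-contained sketch of the registry-plus-wrapper shape this file now uses: factories are registered by name in `ATTENTION_BACKENDS`, and `attn_backend_wrapper` decides after construction whether the full-attention backend should be wrapped for hybrid GDN models. The stand-in classes (`FullAttnBackend`, the simplified `HybridLinearAttnBackend`, `DummyRunner`) and their attributes are illustrative, not sglang's real implementations.

```python
# Minimal sketch of the registry-plus-wrapper pattern; stand-in classes only.

ATTENTION_BACKENDS = {}


def register_attention_backend(name):
    """Map a backend name to a factory that takes the model runner."""

    def decorator(factory):
        ATTENTION_BACKENDS[name] = factory
        return factory

    return decorator


class FullAttnBackend:
    """Stand-in for a regular full-attention backend (e.g. Triton)."""

    def __init__(self, runner):
        self.runner = runner


class HybridLinearAttnBackend:
    """Stand-in for the wrapper that mixes full attention and linear attention."""

    def __init__(self, full_attn_backend, linear_attn_backend, full_attn_layers):
        self.full_attn_backend = full_attn_backend
        self.linear_attn_backend = linear_attn_backend
        self.full_attn_layers = full_attn_layers


@register_attention_backend("triton")
def create_triton_backend(runner):
    return FullAttnBackend(runner)


def attn_backend_wrapper(runner, full_attn_backend):
    """Wrap the plain backend only for hybrid GDN models; otherwise pass it through."""
    if runner.is_hybrid_gdn:
        linear_attn_backend = object()  # stand-in for MambaAttnBackend(runner)
        return HybridLinearAttnBackend(
            full_attn_backend, linear_attn_backend, runner.full_attention_layer_ids
        )
    return full_attn_backend


class DummyRunner:
    is_hybrid_gdn = True
    full_attention_layer_ids = [0, 4, 8]  # hypothetical layer ids


# Resolution mirrors the lookup-then-wrap flow in the diff: build the factory's
# backend by name, then apply the wrapper.
runner = DummyRunner()
backend = attn_backend_wrapper(runner, ATTENTION_BACKENDS["triton"](runner))
print(type(backend).__name__)  # -> HybridLinearAttnBackend
```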
...
@@ -60,7 +60,10 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
-from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -347,7 +350,6 @@ class ModelRunner:
         if self.is_hybrid_gdn:
             logger.warning("Hybrid GDN model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            self.server_args.attention_backend = "hybrid_linear_attn"
             if self.server_args.max_mamba_cache_size is None:
                 if self.server_args.max_running_requests is not None:
                     self.server_args.max_mamba_cache_size = (
@@ -1648,10 +1650,9 @@ class ModelRunner:
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if _is_npu and self.server_args.attention_backend in [
-                "ascend",
-                "hybrid_linear_attn",
-            ]:
+            if _is_npu and (
+                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+            ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
@@ -1764,7 +1765,8 @@ class ModelRunner:
     def _get_attention_backend_from_str(self, backend_str: str):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
-        return ATTENTION_BACKENDS[backend_str](self)
+        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
+        return attn_backend_wrapper(self, full_attention_backend)
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
...
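
One behavioral detail from the token_to_kv_pool_allocator hunk: on NPU, the Ascend paged allocator is now selected from the platform and the model type rather than from the removed backend string. Below is a small standalone predicate (hypothetical helper name, plain booleans) that mirrors that condition; it is an illustration, not the real ModelRunner code.

```python
# Hypothetical standalone mirror of the allocator condition in the hunk above:
# on NPU, hybrid GDN models use the Ascend paged allocator even though they no
# longer carry a dedicated "hybrid_linear_attn" backend string.
def use_ascend_paged_allocator(
    is_npu: bool, attention_backend: str, is_hybrid_gdn: bool
) -> bool:
    return is_npu and (attention_backend == "ascend" or is_hybrid_gdn)


assert use_ascend_paged_allocator(True, "ascend", False)       # Ascend backend on NPU
assert use_ascend_paged_allocator(True, "ascend", True)        # hybrid GDN on NPU
assert not use_ascend_paged_allocator(False, "ascend", False)  # not on NPU
```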
@@ -100,7 +100,6 @@ ATTENTION_BACKEND_CHOICES = [
     "trtllm_mla",
     "trtllm_mha",
     "dual_chunk_flash_attn",
-    "hybrid_linear_attn",
     # AMD specific
     "aiter",
     "wave",
@@ -801,7 +800,7 @@ class ServerArgs:
             self.speculative_algorithm is None
         ), "Speculative decoding is currently not supported with Flex Attention backend"
 
-        if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
+        if is_npu() and self.attention_backend in ["ascend"]:
             logger.warning(
                 "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
             )
...
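
With "hybrid_linear_attn" removed from ATTENTION_BACKEND_CHOICES, the flag only accepts regular full-attention backends, and hybrid wrapping happens implicitly for hybrid GDN models. The sketch below is a hypothetical, abridged mirror of that CLI constraint: the choices list is incomplete, and "triton" / "ascend" are included only because the asserts earlier in this commit point users at them.

```python
# Hypothetical, abridged mirror of the --attention-backend choices after this
# change; the real list in ServerArgs contains more entries.
import argparse

ATTENTION_BACKEND_CHOICES = [
    "triton",
    "ascend",
    "trtllm_mla",
    "trtllm_mha",
    "dual_chunk_flash_attn",
    # AMD specific
    "aiter",
    "wave",
]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    choices=ATTENTION_BACKEND_CHOICES,
    default=None,
    help="Full-attention backend; hybrid GDN wrapping is applied on top of it.",
)

# Hybrid GDN users now pass a regular backend, e.g. triton on Blackwell GPUs:
args = parser.parse_args(["--attention-backend", "triton"])
print(args.attention_backend)

# "hybrid_linear_attn" is no longer accepted and would exit with an argparse error.
```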