Unverified commit 2a2ff9a8, authored Sep 18, 2025 by Yineng Zhang, committed via GitHub on Sep 18, 2025

refactor: use registry for _get_attention_backend_from_str (#10629)

Parent: 5291f32d
Showing 2 changed files with 195 additions and 148 deletions.
python/sglang/srt/layers/attention/attention_registry.py: +192, -0
python/sglang/srt/model_executor/model_runner.py: +3, -148
python/sglang/srt/layers/attention/attention_registry.py (new file, 0 → 100644)
```python
ATTENTION_BACKENDS = {}


def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn

    return decorator


@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
    import torch

    if not runner.use_mla_backend:
        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend

        # Init streams
        if runner.server_args.speculative_algorithm == "EAGLE":
            if (
                not hasattr(runner, "plan_stream_for_flashinfer")
                or not runner.plan_stream_for_flashinfer
            ):
                runner.plan_stream_for_flashinfer = torch.cuda.Stream()
        return FlashInferAttnBackend(runner)
    else:
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAAttnBackend,
        )

        return FlashInferMLAAttnBackend(runner)


@register_attention_backend("trtllm_mla")
def create_trtllm_mla_backend(runner):
    if not runner.use_mla_backend:
        raise ValueError("trtllm_mla backend can only be used with MLA models.")
    from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

    return TRTLLMMLABackend(runner)


@register_attention_backend("aiter")
def create_aiter_backend(runner):
    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

    return AiterAttnBackend(runner)


@register_attention_backend("wave")
def create_wave_backend(runner):
    from sglang.srt.layers.attention.wave_backend import WaveAttnBackend

    return WaveAttnBackend(runner)


@register_attention_backend("ascend")
def create_ascend_backend(runner):
    from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

    return AscendAttnBackend(runner)


@register_attention_backend("triton")
def create_triton_backend(runner):
    assert not runner.model_config.is_encoder_decoder, (
        "Cross attention is not supported in the triton attention backend. "
        "Please use `--attention-backend flashinfer`."
    )
    if runner.server_args.enable_double_sparsity:
        from sglang.srt.layers.attention.double_sparsity_backend import (
            DoubleSparseAttnBackend,
        )

        return DoubleSparseAttnBackend(runner)
    else:
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return TritonAttnBackend(runner)


@register_attention_backend("torch_native")
def create_torch_native_backend(runner):
    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend

    return TorchNativeAttnBackend(runner)


@register_attention_backend("flex_attention")
def create_flex_attention_backend(runner):
    from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend

    return TorchFlexAttnBackend(runner)


@register_attention_backend("flashmla")
def create_flashmla_backend(runner):
    from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

    return FlashMLABackend(runner)


@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
    import torch

    assert (
        torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
    ) or torch.cuda.get_device_capability()[0] == 9, (
        "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
        "Please use `--attention-backend flashinfer`."
    )
    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

    return FlashAttentionBackend(runner)


@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    assert (
        runner.use_mla_backend
    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

    return FlashAttentionBackend(runner, fa_impl_ver=4)


@register_attention_backend("cutlass_mla")
def create_cutlass_mla_backend(runner):
    from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend

    return CutlassMLABackend(runner)


@register_attention_backend("trtllm_mha")
def create_trtllm_mha_backend(runner):
    if runner.use_mla_backend:
        raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
    from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

    return TRTLLMHAAttnBackend(runner)


@register_attention_backend("intel_amx")
def create_intel_amx_backend(runner):
    from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend

    return IntelAMXAttnBackend(runner)


@register_attention_backend("dual_chunk_flash_attn")
def create_dual_chunk_flash_attn_backend(runner):
    from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
        DualChunkFlashAttentionBackend,
    )

    return DualChunkFlashAttentionBackend(runner)


@register_attention_backend("hybrid_linear_attn")
def create_hybrid_linear_attn_backend(runner):
    assert (
        runner.is_hybrid_gdn
    ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
    from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
        HybridLinearAttnBackend,
        MambaAttnBackend,
    )
    from sglang.srt.utils import is_blackwell, is_npu

    if is_npu():
        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

        full_attn_backend = AscendAttnBackend(runner)
    elif is_blackwell():
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        full_attn_backend = TritonAttnBackend(runner)
    else:
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionBackend,
        )

        full_attn_backend = FlashAttentionBackend(runner)

    linear_attn_backend = MambaAttnBackend(runner)
    full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
    return HybridLinearAttnBackend(
        full_attn_backend, linear_attn_backend, full_attn_layers
    )
```
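With the factories collected in `ATTENTION_BACKENDS`, adding a backend no longer requires editing `ModelRunner`. The sketch below shows how an out-of-tree backend could hook into the registry; the name `"my_backend"`, the factory `create_my_backend`, and the reuse of `TorchNativeAttnBackend` as a stand-in are assumptions for illustration only, not part of this commit.

```python
# Hypothetical example of registering an additional backend through the new
# registry. "my_backend" and create_my_backend are invented for this sketch.
from sglang.srt.layers.attention.attention_registry import (
    ATTENTION_BACKENDS,
    register_attention_backend,
)


@register_attention_backend("my_backend")
def create_my_backend(runner):
    # A real factory would construct its own AttentionBackend; here we simply
    # reuse the torch-native backend as a placeholder.
    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend

    return TorchNativeAttnBackend(runner)


# ModelRunner._get_attention_backend_from_str resolves names through this same
# dict, so the new entry is dispatched exactly like the built-in ones.
assert "my_backend" in ATTENTION_BACKENDS
```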
python/sglang/srt/model_executor/model_runner.py
```diff
@@ -60,6 +60,7 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
```
```diff
@@ -1733,155 +1734,9 @@ class ModelRunner:
         return attn_backend

     def _get_attention_backend_from_str(self, backend_str: str):
-        if backend_str == "flashinfer":
-            if not self.use_mla_backend:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                )
-
-                # Init streams
-                if self.server_args.speculative_algorithm == "EAGLE":
-                    if (
-                        not hasattr(self, "plan_stream_for_flashinfer")
-                        or not self.plan_stream_for_flashinfer
-                    ):
-                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                return FlashInferAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                )
-
-                return FlashInferMLAAttnBackend(self)
-        elif backend_str == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-            return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "wave":
-            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
-
-            return WaveAttnBackend(self)
-        elif backend_str == "ascend":
-            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-            return AscendAttnBackend(self)
-        elif backend_str == "triton":
-            assert not self.model_config.is_encoder_decoder, (
-                "Cross attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            if self.server_args.enable_double_sparsity:
-                from sglang.srt.layers.attention.double_sparsity_backend import (
-                    DoubleSparseAttnBackend,
-                )
-
-                return DoubleSparseAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                return TritonAttnBackend(self)
-        elif backend_str == "torch_native":
-            from sglang.srt.layers.attention.torch_native_backend import (
-                TorchNativeAttnBackend,
-            )
-
-            return TorchNativeAttnBackend(self)
-        elif backend_str == "flex_attention":
-            from sglang.srt.layers.attention.torch_flex_backend import (
-                TorchFlexAttnBackend,
-            )
-
-            return TorchFlexAttnBackend(self)
-        elif backend_str == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
-
-            return FlashMLABackend(self)
-        elif backend_str == "fa3":
-            assert (
-                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
-            ) or torch.cuda.get_device_capability()[0] == 9, (
-                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-            )
-
-            return FlashAttentionBackend(self)
-        elif backend_str == "fa4":
-            assert (
-                self.use_mla_backend
-            ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-            )
-
-            return FlashAttentionBackend(self, fa_impl_ver=4)
-        elif backend_str == "cutlass_mla":
-            from sglang.srt.layers.attention.cutlass_mla_backend import (
-                CutlassMLABackend,
-            )
-
-            return CutlassMLABackend(self)
-        elif backend_str == "trtllm_mla":
-            if not self.use_mla_backend:
-                raise ValueError("trtllm_mla backend can only be used with MLA models.")
-            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-            return TRTLLMMLABackend(self)
-        elif backend_str == "trtllm_mha":
-            if self.use_mla_backend:
-                raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-            )
-
-            return TRTLLMHAAttnBackend(self)
-        elif backend_str == "intel_amx":
-            from sglang.srt.layers.attention.intel_amx_backend import (
-                IntelAMXAttnBackend,
-            )
-
-            return IntelAMXAttnBackend(self)
-        elif backend_str == "dual_chunk_flash_attn":
-            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
-                DualChunkFlashAttentionBackend,
-            )
-
-            return DualChunkFlashAttentionBackend(self)
-        elif backend_str == "hybrid_linear_attn":
-            assert (
-                self.is_hybrid_gdn
-            ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
-            from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
-                HybridLinearAttnBackend,
-                MambaAttnBackend,
-            )
-
-            if _is_npu:
-                from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-                full_attn_backend = AscendAttnBackend(self)
-            elif is_blackwell():
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                full_attn_backend = TritonAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashattention_backend import (
-                    FlashAttentionBackend,
-                )
-
-                full_attn_backend = FlashAttentionBackend(self)
-
-            linear_attn_backend = MambaAttnBackend(self)
-            full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
-            return HybridLinearAttnBackend(
-                full_attn_backend, linear_attn_backend, full_attn_layers
-            )
-        else:
+        if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        return ATTENTION_BACKENDS[backend_str](self)

     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
```
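A quick way to sanity-check the new dispatch path is to assert that every name in the registry maps to a callable factory. The pytest-style check below is a suggestion, not part of the commit; it only needs `sglang` to be importable, since the factories defer their backend imports until they are called.

```python
# Hypothetical sanity check for the registry (not included in this commit).
from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS


def test_registry_entries_are_callable():
    # Importing the module runs the @register_attention_backend decorators,
    # so the dict should already be populated.
    assert ATTENTION_BACKENDS, "registry should not be empty after import"
    for name, factory in ATTENTION_BACKENDS.items():
        assert isinstance(name, str) and name
        assert callable(factory)
```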