sglang / Commits / 7b36c47b

Clean up attention backend selection code & Other minor rename (#12136)

Authored Oct 25, 2025 by Lianmin Zheng; committed by GitHub on Oct 25, 2025.
Parent: 773d89da

Showing 6 changed files with 282 additions and 275 deletions (+282 / -275).
Changed files:

  python/sglang/srt/entrypoints/http_server.py            +5    -0
  python/sglang/srt/managers/scheduler.py                  +12   -9
  python/sglang/srt/managers/scheduler_metrics_mixin.py    +15   -12
  python/sglang/srt/model_executor/model_runner.py         +0    -145
  python/sglang/srt/server_args.py                         +177  -31
  python/sglang/srt/utils/common.py                        +73   -78
python/sglang/srt/entrypoints/http_server.py (+5, -0)

@@ -498,6 +498,11 @@ async def get_server_info():
     internal_states: List[Dict[Any, Any]] = (
         await _global_state.tokenizer_manager.get_internal_state()
     )
+    # This field is not serializable.
+    if hasattr(_global_state.tokenizer_manager.server_args, "model_config"):
+        del _global_state.tokenizer_manager.server_args.model_config
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
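The added guard removes the model_config attribute that ServerArgs.get_model_config() caches on the args object (see the server_args.py hunk below) before the args are serialized into the response, since that field is not serializable. A minimal sketch of the same hasattr/del pattern on a hypothetical stand-in object, not the real ServerArgs:

import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class DemoArgs:
    # Hypothetical stand-in for ServerArgs, for illustration only.
    model_path: str = "dummy-model"
    attention_backend: str = "triton"


args = DemoArgs()
# get_model_config() caches a ModelConfig object on the instance; an opaque
# object stands in for it here.
args.model_config = object()

# Same defensive cleanup as in get_server_info(): drop the cached,
# non-serializable attribute before exposing the args as a plain dict.
if hasattr(args, "model_config"):
    del args.model_config

info: Dict[str, Any] = dataclasses.asdict(args)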
python/sglang/srt/managers/scheduler.py (+12, -9)

@@ -2325,10 +2325,10 @@ class Scheduler(
             self.num_generated_tokens = 0
             self.forward_ct_decode = 0
-            self.spec_num_total_accepted_tokens = 0
-            self.spec_num_total_forward_ct = 0
-            self.cum_spec_accept_length = 0
-            self.cum_spec_accept_count = 0
+            self.spec_num_accepted_tokens = 0
+            self.spec_num_forward_ct = 0
+            self.spec_total_num_accepted_tokens = 0
+            self.spec_total_num_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
@@ -2401,13 +2401,16 @@ class Scheduler(
             self.tp_worker.model_runner.graph_mem_usage, 2
         )
-        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+        if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
             ret["avg_spec_accept_length"] = (
-                self.cum_spec_accept_length / self.cum_spec_accept_count
+                self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
             )
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
+
+        # This field is not serializable.
+        ret.pop("model_config", None)
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2434,12 +2437,12 @@ class Scheduler(
                 if_success = False
                 break
         if if_success:
-            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+            if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
                 avg_spec_accept_length = (
-                    self.cum_spec_accept_length / self.cum_spec_accept_count
+                    self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
                 )
                 logger.info(f"{avg_spec_accept_length=}")
-                self.cum_spec_accept_length = self.cum_spec_accept_count = 0
+                self.spec_total_num_accepted_tokens = self.spec_total_num_forward_ct = 0
             for k, v in server_args_dict.items():
                 setattr(get_global_server_args(), k, v)
             logger.info(f"Global server args updated! {get_global_server_args()=}")
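The rename separates two pairs of counters: spec_num_accepted_tokens / spec_num_forward_ct accumulate over the recent decode_log_interval batches, while spec_total_num_accepted_tokens / spec_total_num_forward_ct accumulate over the whole server lifetime (see the scheduler_metrics_mixin.py hunk below), and avg_spec_accept_length is the ratio of the lifetime counters. A toy sketch of that bookkeeping outside the Scheduler class, for illustration only:

class SpecDecodeCounters:
    """Toy bookkeeping mirroring the renamed Scheduler counters (illustration only)."""

    def __init__(self):
        # Counters for the recent logging interval (reset when stats are logged).
        self.spec_num_accepted_tokens = 0
        self.spec_num_forward_ct = 0
        # Lifetime counters used for avg_spec_accept_length.
        self.spec_total_num_accepted_tokens = 0
        self.spec_total_num_forward_ct = 0

    def update(self, bs: int, num_accepted_tokens: int) -> None:
        # Same accumulation as update_spec_metrics() in the mixin diff.
        self.spec_num_accepted_tokens += num_accepted_tokens + bs
        self.spec_num_forward_ct += bs

    def roll_over(self) -> None:
        # Fold the interval counters into the lifetime totals and reset them,
        # as the logging hunk in scheduler_metrics_mixin.py does.
        self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens
        self.spec_total_num_forward_ct += self.spec_num_forward_ct
        self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0

    def avg_spec_accept_length(self) -> float:
        # Ratio reported by get_internal_state() when the forward count is > 0.
        if self.spec_total_num_forward_ct == 0:
            return 0.0
        return self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct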
python/sglang/srt/managers/scheduler_metrics_mixin.py (+15, -12)

@@ -39,10 +39,13 @@ class SchedulerMetricsMixin:
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
+        # The number of accepted tokens and forward ct for the recent `decode_log_interval` batches (for logging)
+        self.spec_num_accepted_tokens = 0
+        self.spec_num_forward_ct = 0
+        # The total number of accepted tokens and forward ct for the whole server lifetime
+        self.spec_total_num_accepted_tokens = 0
+        self.spec_total_num_forward_ct = 0
+
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
@@ -67,8 +70,8 @@ class SchedulerMetricsMixin:
         )

     def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int):
-        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-        self.spec_num_total_forward_ct += bs
+        self.spec_num_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens

     def log_prefill_stats(
@@ -253,20 +256,20 @@ class SchedulerMetricsMixin:
             spec_accept_rate = 0
         else:
             spec_accept_length = (
-                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+                self.spec_num_accepted_tokens / self.spec_num_forward_ct
             )
             # Calculate acceptance rate: accepted tokens / total draft tokens
-            total_draft_tokens = self.spec_num_total_forward_ct * (
+            total_draft_tokens = self.spec_num_forward_ct * (
                 (self.server_args.speculative_num_steps or 0) + 1
             )
             spec_accept_rate = (
-                self.spec_num_total_accepted_tokens / total_draft_tokens
+                self.spec_num_accepted_tokens / total_draft_tokens
                 if total_draft_tokens > 0
                 else 0
             )
-            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
-            self.cum_spec_accept_count += self.spec_num_total_forward_ct
-            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens
+            self.spec_total_num_forward_ct += self.spec_num_forward_ct
+            self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0
             msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "

         cache_hit_rate = 0.0
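The logging hunk also reports an acceptance rate against the total number of draft tokens, where each forward can propose up to speculative_num_steps + 1 tokens. A worked sketch of the two metrics with the renamed counters, using illustrative numbers only:

# Illustrative numbers only: 8 forwards in the interval, 30 tokens kept in total,
# with speculative_num_steps = 3 (so up to 4 tokens can be emitted per forward).
spec_num_forward_ct = 8
spec_num_accepted_tokens = 30
speculative_num_steps = 3

# accept len: average number of tokens emitted per forward pass.
spec_accept_length = spec_num_accepted_tokens / spec_num_forward_ct  # 3.75

# accept rate: accepted tokens over the maximum the drafts could have yielded.
total_draft_tokens = spec_num_forward_ct * ((speculative_num_steps or 0) + 1)  # 32
spec_accept_rate = (
    spec_num_accepted_tokens / total_draft_tokens if total_draft_tokens > 0 else 0
)  # 0.9375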
python/sglang/srt/model_executor/model_runner.py (+0, -145)

@@ -131,13 +131,8 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_fa3_default_architecture,
-    is_flashinfer_available,
     is_hip,
-    is_hopper_with_cuda_12_3,
-    is_no_spec_infer_or_topk_one,
     is_npu,
-    is_sm100_supported,
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
     set_cuda_arch,
@@ -502,121 +497,6 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args

-        if (
-            server_args.attention_backend == "intel_amx"
-            and server_args.device == "cpu"
-            and not _is_cpu_amx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel AMX, will fallback to torch_native backend."
-            )
-            server_args.attention_backend = "torch_native"
-
-        if (
-            server_args.attention_backend == "intel_xpu"
-            and server_args.device == "xpu"
-            and not _is_xpu_xmx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel XMX, will fallback to triton backend."
-            )
-            server_args.attention_backend = "triton"
-
-        if server_args.prefill_attention_backend is not None and (
-            server_args.prefill_attention_backend == server_args.decode_attention_backend
-        ):  # override the default attention backend
-            server_args.attention_backend = server_args.prefill_attention_backend
-
-        if (
-            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
-            is not None
-        ):
-            if server_args.attention_backend is None:
-                server_args.attention_backend = "dual_chunk_flash_attn"
-                logger.info("Dual chunk attention is turned on by default.")
-            elif server_args.attention_backend != "dual_chunk_flash_attn":
-                raise ValueError(
-                    "Dual chunk attention is enabled, but attention backend is set to "
-                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
-                )
-
-        if server_args.attention_backend is None:
-            """
-            Auto select the fastest attention backend.
-
-            1. Models with MHA Architecture (e.g: Llama, QWen)
-                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
-                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
-            2. Models with MLA Architecture and using FA3
-                2.1 We will use FA3 backend on hopper.
-                2.2 We will use Flashinfer backend on blackwell.
-                2.3 Otherwise, we will use triton backend.
-            """
-            if not self.use_mla_backend:
-                # MHA architecture
-                if (
-                    is_hopper_with_cuda_12_3()
-                    and is_no_spec_infer_or_topk_one(server_args)
-                    and is_fa3_default_architecture(self.model_config.hf_config)
-                ):
-                    server_args.attention_backend = "fa3"
-                elif _is_hip:
-                    server_args.attention_backend = "aiter"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-            else:
-                # MLA architecture
-                if is_hopper_with_cuda_12_3():
-                    server_args.attention_backend = "fa3"
-                elif is_sm100_supported():
-                    server_args.attention_backend = "flashinfer"
-                elif _is_hip:
-                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
-                    # TODO current aiter only support head number 16 or 128 head number
-                    if head_num == 128 or head_num == 16:
-                        server_args.attention_backend = "aiter"
-                    else:
-                        server_args.attention_backend = "triton"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = "triton"
-            log_info_on_rank0(
-                logger,
-                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
-            )
-        elif self.use_mla_backend:
-            if server_args.device != "cpu":
-                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
-                else:
-                    raise ValueError(
-                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
-                    )
-            else:
-                if server_args.attention_backend != "intel_amx":
-                    raise ValueError(
-                        "MLA optimization not supported on CPU except for intel_amx backend."
-                    )
-
-        if (
-            server_args.attention_backend == "fa3"
-            and server_args.kv_cache_dtype == "fp8_e5m2"
-        ):
-            logger.warning(
-                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
-                "Setting attention backend to triton."
-            )
-            server_args.attention_backend = "triton"
-
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -642,31 +522,6 @@ class ModelRunner:
         if not server_args.disable_chunked_prefix_cache:
             log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

-        if server_args.attention_backend == "aiter":
-            if self.model_config.context_len > 8192:
-                self.mem_fraction_static *= 0.85
-
-        if (
-            server_args.enable_hierarchical_cache
-            and server_args.hicache_io_backend == "kernel"
-        ):
-            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
-            if server_args.decode_attention_backend is None:
-                if not self.use_mla_backend:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-                else:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_sm100_supported() else "triton"
-                    )
-            elif server_args.decode_attention_backend == "fa3":
-                server_args.hicache_io_backend = "direct"
-                logger.warning(
-                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
-                )
-
         if self.model_config.hf_config.model_type == "qwen3_vl_moe":
             if (
                 quantization_config := getattr(
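One detail worth noting in the removed MLA branch (which reappears in server_args.py below): on HIP, the aiter backend is only chosen when the per-TP KV head count is 16 or 128, per the TODO in the diff. A standalone sketch of that rule, with a hypothetical helper name:

def pick_hip_mla_backend(num_kv_heads_per_tp: int) -> str:
    # Hypothetical helper illustrating the rule from the diff:
    # aiter currently only supports KV head counts of 16 or 128; otherwise fall back to triton.
    return "aiter" if num_kv_heads_per_tp in (16, 128) else "triton"


assert pick_hip_mla_backend(128) == "aiter"
assert pick_hip_mla_backend(8) == "triton"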
python/sglang/srt/server_args.py (+177, -31)

@@ -34,12 +34,16 @@ from sglang.srt.utils.common import (
     LORA_TARGET_ALL_MODULES,
     SUPPORTED_LORA_TARGET_MODULES,
     configure_ipv6,
+    cpu_has_amx_support,
     get_device,
     get_device_memory_capacity,
     get_device_sm,
     is_cuda,
+    is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
+    is_hopper_with_cuda_12_3,
+    is_no_spec_infer_or_topk_one,
     is_npu,
     is_port_available,
     is_remote_url,
@@ -51,6 +55,7 @@ from sglang.srt.utils.common import (
     json_list_type,
     nullable_str,
     parse_connector_type,
+    xpu_has_xmx_support,
 )
 from sglang.srt.utils.hf_transformers_utils import check_gguf_file, get_config
 from sglang.utils import is_in_ci
@@ -545,6 +550,9 @@ class ServerArgs:
         # Apply model-specific adjustments.
         self._handle_model_specific_adjustments()

+        # Handle Hicache settings.
+        self._handle_hicache()
+
         # Set kernel backends.
         self._handle_sampling_backend()
         self._handle_attention_backend_compatibility()
@@ -567,9 +575,6 @@ class ServerArgs:
         # Handle pipeline parallelism.
         self._handle_pipeline_parallelism()

-        # Handle Hicache settings.
-        self._handle_hicache()
-
         # Handle speculative decoding logic.
         self._handle_speculative_decoding()
@@ -779,11 +784,9 @@ class ServerArgs:
                 else 0.88
             )

-            # Lazy init to avoid circular import
-            # Multimodal models need more memory for the image processing,
-            # so we adjust the mem_fraction_static accordingly.
-            from sglang.srt.configs.model_config import ModelConfig
-
-            model_config = ModelConfig.from_server_args(self)
+            # Multimodal models need more memory for the image processor
+            model_config = self.get_model_config()
             if model_config.is_multimodal:
                 self.adjust_mem_fraction_for_vlm(model_config)
@@ -1042,6 +1045,67 @@ class ServerArgs:
     def _handle_attention_backend_compatibility(self):
+        model_config = self.get_model_config()
+        use_mla_backend = self.use_mla_backend()
+
+        if self.prefill_attention_backend is not None and (
+            self.prefill_attention_backend == self.decode_attention_backend
+        ):  # override the default attention backend
+            self.attention_backend = self.prefill_attention_backend
+
+        # Pick the default attention backend if not specified
+        if self.attention_backend is None:
+            """
+            Auto select the fastest attention backend.
+
+            1. Models with MHA Architecture (e.g: Llama, QWen)
+                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
+                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
+            2. Models with MLA Architecture and using FA3
+                2.1 We will use FA3 backend on hopper.
+                2.2 We will use Flashinfer backend on blackwell.
+                2.3 Otherwise, we will use triton backend.
+            """
+            if not use_mla_backend:
+                # MHA architecture
+                if (
+                    is_hopper_with_cuda_12_3()
+                    and is_no_spec_infer_or_topk_one(self)
+                    and is_fa3_default_architecture(self.model_config.hf_config)
+                ):
+                    self.attention_backend = "fa3"
+                elif is_hip():
+                    self.attention_backend = "aiter"
+                elif is_npu():
+                    self.attention_backend = "ascend"
+                else:
+                    self.attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+            else:
+                # MLA architecture
+                if is_hopper_with_cuda_12_3():
+                    self.attention_backend = "fa3"
+                elif is_sm100_supported():
+                    self.attention_backend = "flashinfer"
+                elif is_hip():
+                    head_num = model_config.get_num_kv_heads(self.tp_size)
+                    # TODO current aiter only support head number 16 or 128 head number
+                    if head_num == 128 or head_num == 16:
+                        self.attention_backend = "aiter"
+                    else:
+                        self.attention_backend = "triton"
+                elif is_npu():
+                    self.attention_backend = "ascend"
+                else:
+                    self.attention_backend = "triton"
+            logger.warning(
+                f"Attention backend not explicitly specified. Use {self.attention_backend} backend by default."
+            )
+
+        # Torch native and flex attention backends
         if self.attention_backend == "torch_native":
             logger.warning(
                 "Cuda graph is disabled because of using torch native attention backend"
@@ -1057,12 +1121,7 @@ class ServerArgs:
                 self.speculative_algorithm is None
             ), "Speculative decoding is currently not supported with Flex Attention backend"

-        if is_npu() and self.attention_backend in ["ascend"]:
-            logger.warning(
-                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
-            )
-            self.page_size = 128
-
+        # Major NVIDIA platforms backends
         if (
             self.attention_backend == "flashmla"
             or self.decode_attention_backend == "flashmla"
@@ -1117,19 +1176,13 @@ class ServerArgs:
             )
             self.page_size = 64

-        if self.attention_backend == "dual_chunk_flash_attn":
+        if self.attention_backend == "fa3" and self.kv_cache_dtype == "fp8_e5m2":
             logger.warning(
-                "Mixed chunk and radix cache are disabled when using dual-chunk flash attention backend"
+                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+                "Setting attention backend to triton."
             )
-            self.enable_mixed_chunk = False
-            self.disable_radix_cache = True
-
-        if self.attention_backend == "intel_xpu":
-            if self.page_size not in [32, 64, 128]:
-                logger.warning(
-                    f"Intel XPU attention backend only supports page_size of 32, 64 or 128, changing page_size from {self.page_size} to 128."
-                )
-                self.page_size = 128
+            self.attention_backend = "triton"

         if self.attention_backend == "fa4" or self.decode_attention_backend == "fa4":
             raise ValueError(
                 "FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead."
@@ -1140,6 +1193,66 @@ class ServerArgs:
             )
             self.page_size = 128

+        # AMD platforms backends
+        if self.attention_backend == "aiter":
+            if model_config.context_len > 8192:
+                self.mem_fraction_static *= 0.90
+
+        # NPU platforms backends
+        if is_npu() and self.attention_backend in ["ascend"]:
+            logger.warning(
+                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
+        # Other platforms backends
+        if (
+            self.attention_backend == "intel_amx"
+            and self.device == "cpu"
+            and not cpu_has_amx_support()
+        ):
+            logger.warning(
+                "The current platform does not support Intel AMX, will fallback to torch_native backend."
+            )
+            self.attention_backend = "torch_native"
+
+        if (
+            self.attention_backend == "intel_xpu"
+            and self.device == "xpu"
+            and not xpu_has_xmx_support()
+        ):
+            logger.warning(
+                "The current platform does not support Intel XMX, will fallback to triton backend."
+            )
+            self.attention_backend = "triton"
+
+        if self.attention_backend == "intel_xpu":
+            if self.page_size not in [32, 64, 128]:
+                logger.warning(
+                    f"Intel XPU attention backend only supports page_size of 32, 64 or 128, changing page_size from {self.page_size} to 128."
+                )
+                self.page_size = 128
+
+        # Dual chunk flash attention backend
+        if (
+            getattr(model_config.hf_config, "dual_chunk_attention_config", None)
+            is not None
+        ):
+            if self.attention_backend is None:
+                self.attention_backend = "dual_chunk_flash_attn"
+                logger.info("Dual chunk attention is turned on by default.")
+            elif self.attention_backend != "dual_chunk_flash_attn":
+                raise ValueError(
+                    "Dual chunk attention is enabled, but attention backend is set to "
+                    f"{self.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
+                )
+
+        if self.attention_backend == "dual_chunk_flash_attn":
+            logger.warning(
+                "Mixed chunk and radix cache are disabled when using dual-chunk flash attention backend"
+            )
+            self.enable_mixed_chunk = False
+            self.disable_radix_cache = True
+
     def _handle_page_size(self):
         if self.page_size is None:
             self.page_size = 1
@@ -1283,6 +1396,24 @@ class ServerArgs:
                 "Page first direct layout only support direct io backend"
             )

+        if self.enable_hierarchical_cache and self.hicache_io_backend == "kernel":
+            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+            if self.decode_attention_backend is None:
+                if not self.use_mla_backend():
+                    self.decode_attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+                else:
+                    self.decode_attention_backend = (
+                        "flashinfer" if is_sm100_supported() else "triton"
+                    )
+            elif self.decode_attention_backend == "fa3":
+                self.hicache_io_backend = "direct"
+                logger.warning(
+                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
     def _handle_speculative_decoding(self):
         if self.speculative_algorithm == "NEXTN":
             self.speculative_algorithm = "EAGLE"
@@ -3355,19 +3486,34 @@ class ServerArgs:
         )
         return hf_config

-    def get_attention_backends(server_args):
+    def get_model_config(self):
+        # Lazy init to avoid circular import
+        from sglang.srt.configs.model_config import ModelConfig
+
+        if hasattr(self, "model_config"):
+            return self.model_config
+        self.model_config = ModelConfig.from_server_args(self)
+        return self.model_config
+
+    def get_attention_backends(self):
         prefill_attention_backend_str = (
-            server_args.prefill_attention_backend
-            if server_args.prefill_attention_backend
-            else server_args.attention_backend
+            self.prefill_attention_backend
+            if self.prefill_attention_backend
+            else self.attention_backend
        )
         decode_attention_backend_str = (
-            server_args.decode_attention_backend
-            if server_args.decode_attention_backend
-            else server_args.attention_backend
+            self.decode_attention_backend
+            if self.decode_attention_backend
+            else self.attention_backend
         )
         return prefill_attention_backend_str, decode_attention_backend_str

+    def use_mla_backend(self):
+        from sglang.srt.configs.model_config import AttentionArch
+
+        model_config = self.get_model_config()
+        return model_config.attention_arch == AttentionArch.MLA
+
     def check_server_args(self):
         # Check parallel size constraints
         assert (
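After this change, ServerArgs owns both the lazily cached ModelConfig and the backend resolution, and ModelRunner's copy of the selection logic is gone. A small sketch of the two helper patterns in isolation, using a hypothetical stand-in object rather than a real ServerArgs:

from types import SimpleNamespace


class ArgsSketch:
    """Hypothetical stand-in illustrating the two new ServerArgs helpers."""

    def __init__(self):
        self.attention_backend = "fa3"
        self.prefill_attention_backend = None
        self.decode_attention_backend = "flashinfer"

    def get_model_config(self):
        # Lazy init, mirroring ServerArgs.get_model_config(): build the config
        # once, cache it on the instance, and reuse it afterwards.
        if hasattr(self, "model_config"):
            return self.model_config
        self.model_config = SimpleNamespace(is_multimodal=False)  # placeholder
        return self.model_config

    def get_attention_backends(self):
        # Phase-specific backend wins; otherwise fall back to the shared one.
        prefill = self.prefill_attention_backend or self.attention_backend
        decode = self.decode_attention_backend or self.attention_backend
        return prefill, decode


args = ArgsSketch()
assert args.get_model_config() is args.get_model_config()  # cached after first call
assert args.get_attention_backends() == ("fa3", "flashinfer")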
python/sglang/srt/utils/common.py (+73, -78)

@@ -2096,80 +2096,80 @@ class MultiprocessingSerializer:
             # Decode base64 string to bytes
             data = pybase64.b64decode(data, validate=True)

         class SafeUnpickler(pickle.Unpickler):
             ALLOWED_MODULE_PREFIXES = {
                 # --- Python types ---
                 "builtins.",
                 "collections.",
                 "copyreg.",
                 "functools.",
                 "itertools.",
                 "operator.",
                 "types.",
                 "weakref.",
                 # --- PyTorch types ---
                 "torch.",
                 "torch._tensor.",
                 "torch.storage.",
                 "torch.nn.parameter.",
                 "torch.autograd.function.",
                 # --- torch distributed ---
                 "torch.distributed.",
                 "torch.distributed._shard.",
                 "torch.distributed._composable.",
                 "torch._C._distributed_c10d.",
                 "torch._C._distributed_fsdp.",
                 "torch.distributed.optim.",
                 # --- multiprocessing ---
                 "multiprocessing.resource_sharer.",
                 "multiprocessing.reduction.",
                 "pickletools.",
                 # --- PEFT / LoRA ---
                 "peft.",
                 "transformers.",
                 "huggingface_hub.",
                 # --- SGLang & Unitest ---
                 "sglang.srt.weight_sync.tensor_bucket.",
                 "sglang.srt.model_executor.model_runner.",
                 "sglang.srt.layers.",
                 "sglang.srt.utils.",
             }

             DENY_CLASSES = {
                 ("builtins", "eval"),
                 ("builtins", "exec"),
                 ("builtins", "compile"),
                 ("os", "system"),
                 ("subprocess", "Popen"),
                 ("subprocess", "run"),
                 ("codecs", "decode"),
                 ("types", "CodeType"),
                 ("types", "FunctionType"),
             }

             def find_class(self, module, name):
                 # Block deterministic attacks
                 if (module, name) in self.DENY_CLASSES:
                     raise RuntimeError(
                         f"Blocked unsafe class loading ({module}.{name}), "
                         f"to prevent exploitation of CVE-2025-10164"
                     )

                 # Allowlist of safe-to-load modules.
                 if any(
                     (module + ".").startswith(prefix)
                     for prefix in self.ALLOWED_MODULE_PREFIXES
                 ):
                     return super().find_class(module, name)

                 # Block everything else. (Potential attack surface)
                 raise RuntimeError(
                     f"Blocked unsafe class loading ({module}.{name}), "
                     f"to prevent exploitation of CVE-2025-10164"
                 )

         return SafeUnpickler(io.BytesIO(data)).load()


 def debug_timing(func):
     # todo: replace with a more organized instrumentation
     def wrapper(*args, **kwargs):
@@ -2620,17 +2620,12 @@ def get_local_ip_auto(fallback: str = None) -> str:
     raise ValueError("Can not get local ip")


-def is_page_size_one(server_args):
-    return server_args.page_size == 1
-
-
 # TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
 # TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
 def is_no_spec_infer_or_topk_one(server_args):
     return server_args.speculative_eagle_topk is None or (
-        server_args.speculative_eagle_topk is not None
-        and server_args.speculative_eagle_topk == 1
-        and is_page_size_one(server_args)
+        server_args.speculative_eagle_topk == 1
+        and (server_args.page_size == 1 or server_args.page_size is None)
     )
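The helper change folds is_page_size_one into is_no_spec_infer_or_topk_one and drops the redundant `is not None` check (if topk equals 1 it is necessarily not None). A quick check of the simplified predicate against hypothetical args objects, for illustration only:

from types import SimpleNamespace


def is_no_spec_infer_or_topk_one(server_args):
    # Simplified predicate from the diff: no speculative decoding configured,
    # or an EAGLE top-k of 1 with an (effective) page size of 1.
    return server_args.speculative_eagle_topk is None or (
        server_args.speculative_eagle_topk == 1
        and (server_args.page_size == 1 or server_args.page_size is None)
    )


# Hypothetical configurations, for illustration only.
no_spec = SimpleNamespace(speculative_eagle_topk=None, page_size=16)
topk_one = SimpleNamespace(speculative_eagle_topk=1, page_size=None)
topk_many = SimpleNamespace(speculative_eagle_topk=4, page_size=1)

assert is_no_spec_infer_or_topk_one(no_spec) is True
assert is_no_spec_infer_or_topk_one(topk_one) is True
assert is_no_spec_infer_or_topk_one(topk_many) is False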