Unverified Commit 7b36c47b authored by Lianmin Zheng, committed by GitHub

Clean up attention backend selection code & other minor renames (#12136)

parent 773d89da
@@ -498,6 +498,11 @@ async def get_server_info():
     internal_states: List[Dict[Any, Any]] = (
         await _global_state.tokenizer_manager.get_internal_state()
     )
+
+    # This field is not serializable.
+    if hasattr(_global_state.tokenizer_manager.server_args, "model_config"):
+        del _global_state.tokenizer_manager.server_args.model_config
+
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
...
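The deletion above exists because `get_model_config()` (added to `ServerArgs` later in this diff) caches a full `ModelConfig` object on the args instance, and that object breaks serialization of the server-info payload; both this hunk and the scheduler hunk below drop `model_config` before the state is returned. A minimal, hedged illustration of the failure mode, using stand-in names rather than the real SGLang types:

```python
import json


class ModelConfigStub:
    """Stand-in for a cached config object that json.dumps cannot handle."""

    def __init__(self) -> None:
        self.vocab_size = 32000


internal_state = {
    "last_gen_throughput": 123.4,
    "avg_spec_accept_length": 2.7,
    "model_config": ModelConfigStub(),  # cached object, not JSON-serializable
}

# Serializing fails while the cached object is still in the payload.
try:
    json.dumps(internal_state)
except TypeError as e:
    print(f"serialization failed: {e}")

# Dropping the field first (as the diff does with `del` / `ret.pop`) fixes it.
internal_state.pop("model_config", None)
print(json.dumps(internal_state))
```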
@@ -2325,10 +2325,10 @@ class Scheduler(
         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
+        self.spec_num_accepted_tokens = 0
+        self.spec_num_forward_ct = 0
+        self.spec_total_num_accepted_tokens = 0
+        self.spec_total_num_forward_ct = 0
         torch.cuda.empty_cache()
         logger.info("Cache flushed successfully!")
         if_success = True
@@ -2401,13 +2401,16 @@ class Scheduler(
             self.tp_worker.model_runner.graph_mem_usage, 2
         )
-        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+        if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
             ret["avg_spec_accept_length"] = (
-                self.cum_spec_accept_length / self.cum_spec_accept_count
+                self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
             )
         if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict
+
+        # This field is not serializable.
+        ret.pop("model_config", None)
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2434,12 +2437,12 @@ class Scheduler(
                 if_success = False
                 break
         if if_success:
-            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+            if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
                 avg_spec_accept_length = (
-                    self.cum_spec_accept_length / self.cum_spec_accept_count
+                    self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
                 )
                 logger.info(f"{avg_spec_accept_length=}")
-                self.cum_spec_accept_length = self.cum_spec_accept_count = 0
+                self.spec_total_num_accepted_tokens = self.spec_total_num_forward_ct = 0
             for k, v in server_args_dict.items():
                 setattr(get_global_server_args(), k, v)
             logger.info(f"Global server args updated! {get_global_server_args()=}")
...
@@ -39,10 +39,13 @@ class SchedulerMetricsMixin:
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
+
+        # The number of accepted tokens and forward ct for the recent `decode_log_interval` batches (for logging)
+        self.spec_num_accepted_tokens = 0
+        self.spec_num_forward_ct = 0
+        # The total number of accepted tokens and forward ct for the whole server lifetime
+        self.spec_total_num_accepted_tokens = 0
+        self.spec_total_num_forward_ct = 0
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
@@ -67,8 +70,8 @@ class SchedulerMetricsMixin:
         )

     def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int):
-        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-        self.spec_num_total_forward_ct += bs
+        self.spec_num_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens

     def log_prefill_stats(
@@ -253,20 +256,20 @@ class SchedulerMetricsMixin:
             spec_accept_rate = 0
         else:
             spec_accept_length = (
-                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+                self.spec_num_accepted_tokens / self.spec_num_forward_ct
             )
             # Calculate acceptance rate: accepted tokens / total draft tokens
-            total_draft_tokens = self.spec_num_total_forward_ct * (
+            total_draft_tokens = self.spec_num_forward_ct * (
                 (self.server_args.speculative_num_steps or 0) + 1
             )
             spec_accept_rate = (
-                self.spec_num_total_accepted_tokens / total_draft_tokens
+                self.spec_num_accepted_tokens / total_draft_tokens
                 if total_draft_tokens > 0
                 else 0
             )
-            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
-            self.cum_spec_accept_count += self.spec_num_total_forward_ct
-            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens
+            self.spec_total_num_forward_ct += self.spec_num_forward_ct
+            self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0
             msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "

         cache_hit_rate = 0.0
...
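For readers outside the codebase, the rename above replaces the mixed `spec_num_total_*` / `cum_spec_*` names with two clearly separated counter families: `spec_num_*` accumulate over the recent logging interval and are reset after each log line, while `spec_total_num_*` accumulate over the server lifetime and feed `avg_spec_accept_length` in `get_internal_state`. A stripped-down sketch of that pattern (a toy class, not the real mixin):

```python
class SpecDecodeCounters:
    """Two-tier counters: per-interval (for logging) and lifetime (for server stats)."""

    def __init__(self) -> None:
        # Reset after every decode_log_interval batches.
        self.spec_num_accepted_tokens = 0
        self.spec_num_forward_ct = 0
        # Accumulated over the whole server lifetime.
        self.spec_total_num_accepted_tokens = 0
        self.spec_total_num_forward_ct = 0

    def update(self, bs: int, num_accepted_tokens: int) -> None:
        # Each forward step keeps `bs` target tokens plus the accepted draft tokens.
        self.spec_num_accepted_tokens += num_accepted_tokens + bs
        self.spec_num_forward_ct += bs

    def log_and_roll_over(self) -> float:
        accept_length = self.spec_num_accepted_tokens / self.spec_num_forward_ct
        # Fold the interval counters into the lifetime totals, then reset them.
        self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens
        self.spec_total_num_forward_ct += self.spec_num_forward_ct
        self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0
        return accept_length

    def avg_spec_accept_length(self) -> float:
        # Lifetime average, as reported by get_internal_state.
        if self.spec_total_num_forward_ct == 0:
            return 0.0
        return self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct


counters = SpecDecodeCounters()
counters.update(bs=4, num_accepted_tokens=7)
print(counters.log_and_roll_over())       # 2.75 accepted tokens per forward step
print(counters.avg_spec_accept_length())  # same value until more intervals roll in
```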
@@ -131,13 +131,8 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_fa3_default_architecture,
-    is_flashinfer_available,
     is_hip,
-    is_hopper_with_cuda_12_3,
-    is_no_spec_infer_or_topk_one,
     is_npu,
-    is_sm100_supported,
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
     set_cuda_arch,
@@ -502,121 +497,6 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args

-        if (
-            server_args.attention_backend == "intel_amx"
-            and server_args.device == "cpu"
-            and not _is_cpu_amx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel AMX, will fallback to torch_native backend."
-            )
-            server_args.attention_backend = "torch_native"
-
-        if (
-            server_args.attention_backend == "intel_xpu"
-            and server_args.device == "xpu"
-            and not _is_xpu_xmx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel XMX, will fallback to triton backend."
-            )
-            server_args.attention_backend = "triton"
-
-        if server_args.prefill_attention_backend is not None and (
-            server_args.prefill_attention_backend
-            == server_args.decode_attention_backend
-        ):  # override the default attention backend
-            server_args.attention_backend = server_args.prefill_attention_backend
-
-        if (
-            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
-            is not None
-        ):
-            if server_args.attention_backend is None:
-                server_args.attention_backend = "dual_chunk_flash_attn"
-                logger.info("Dual chunk attention is turned on by default.")
-            elif server_args.attention_backend != "dual_chunk_flash_attn":
-                raise ValueError(
-                    "Dual chunk attention is enabled, but attention backend is set to "
-                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
-                )
-
-        if server_args.attention_backend is None:
-            """
-            Auto select the fastest attention backend.
-
-            1. Models with MHA Architecture (e.g: Llama, QWen)
-                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
-                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
-            2. Models with MLA Architecture and using FA3
-                2.1 We will use FA3 backend on hopper.
-                2.2 We will use Flashinfer backend on blackwell.
-                2.3 Otherwise, we will use triton backend.
-            """
-            if not self.use_mla_backend:
-                # MHA architecture
-                if (
-                    is_hopper_with_cuda_12_3()
-                    and is_no_spec_infer_or_topk_one(server_args)
-                    and is_fa3_default_architecture(self.model_config.hf_config)
-                ):
-                    server_args.attention_backend = "fa3"
-                elif _is_hip:
-                    server_args.attention_backend = "aiter"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-            else:
-                # MLA architecture
-                if is_hopper_with_cuda_12_3():
-                    server_args.attention_backend = "fa3"
-                elif is_sm100_supported():
-                    server_args.attention_backend = "flashinfer"
-                elif _is_hip:
-                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
-                    # TODO current aiter only support head number 16 or 128 head number
-                    if head_num == 128 or head_num == 16:
-                        server_args.attention_backend = "aiter"
-                    else:
-                        server_args.attention_backend = "triton"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = "triton"
-            log_info_on_rank0(
-                logger,
-                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
-            )
-        elif self.use_mla_backend:
-            if server_args.device != "cpu":
-                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
-                else:
-                    raise ValueError(
-                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
-                    )
-            else:
-                if server_args.attention_backend != "intel_amx":
-                    raise ValueError(
-                        "MLA optimization not supported on CPU except for intel_amx backend."
-                    )
-
-        if (
-            server_args.attention_backend == "fa3"
-            and server_args.kv_cache_dtype == "fp8_e5m2"
-        ):
-            logger.warning(
-                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
-                "Setting attention backend to triton."
-            )
-            server_args.attention_backend = "triton"
-
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -642,31 +522,6 @@ class ModelRunner:
         if not server_args.disable_chunked_prefix_cache:
             log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

-        if server_args.attention_backend == "aiter":
-            if self.model_config.context_len > 8192:
-                self.mem_fraction_static *= 0.85
-
-        if (
-            server_args.enable_hierarchical_cache
-            and server_args.hicache_io_backend == "kernel"
-        ):
-            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
-            if server_args.decode_attention_backend is None:
-                if not self.use_mla_backend:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-                else:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_sm100_supported() else "triton"
-                    )
-            elif server_args.decode_attention_backend == "fa3":
-                server_args.hicache_io_backend = "direct"
-                logger.warning(
-                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
-                )
-
         if self.model_config.hf_config.model_type == "qwen3_vl_moe":
             if (
                 quantization_config := getattr(
...
@@ -34,12 +34,16 @@ from sglang.srt.utils.common import (
     LORA_TARGET_ALL_MODULES,
     SUPPORTED_LORA_TARGET_MODULES,
     configure_ipv6,
+    cpu_has_amx_support,
     get_device,
     get_device_memory_capacity,
     get_device_sm,
     is_cuda,
+    is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
+    is_hopper_with_cuda_12_3,
+    is_no_spec_infer_or_topk_one,
     is_npu,
     is_port_available,
     is_remote_url,
@@ -51,6 +55,7 @@ from sglang.srt.utils.common import (
     json_list_type,
     nullable_str,
     parse_connector_type,
+    xpu_has_xmx_support,
 )
 from sglang.srt.utils.hf_transformers_utils import check_gguf_file, get_config
 from sglang.utils import is_in_ci
@@ -545,6 +550,9 @@ class ServerArgs:
         # Apply model-specific adjustments.
         self._handle_model_specific_adjustments()

+        # Handle Hicache settings.
+        self._handle_hicache()
+
         # Set kernel backends.
         self._handle_sampling_backend()
         self._handle_attention_backend_compatibility()
@@ -567,9 +575,6 @@ class ServerArgs:
         # Handle pipeline parallelism.
         self._handle_pipeline_parallelism()

-        # Handle Hicache settings.
-        self._handle_hicache()
-
         # Handle speculative decoding logic.
         self._handle_speculative_decoding()
@@ -779,11 +784,9 @@ class ServerArgs:
                 else 0.88
             )

-            # Lazy init to avoid circular import
-            # Multimodal models need more memory for the image processor
-            from sglang.srt.configs.model_config import ModelConfig
-
-            model_config = ModelConfig.from_server_args(self)
+            # Multimodal models need more memory for the image processing,
+            # so we adjust the mem_fraction_static accordingly.
+            model_config = self.get_model_config()
             if model_config.is_multimodal:
                 self.adjust_mem_fraction_for_vlm(model_config)
@@ -1042,6 +1045,67 @@ class ServerArgs:
             )

     def _handle_attention_backend_compatibility(self):
+        model_config = self.get_model_config()
+        use_mla_backend = self.use_mla_backend()
+
+        if self.prefill_attention_backend is not None and (
+            self.prefill_attention_backend == self.decode_attention_backend
+        ):  # override the default attention backend
+            self.attention_backend = self.prefill_attention_backend
+
+        # Pick the default attention backend if not specified
+        if self.attention_backend is None:
+            """
+            Auto select the fastest attention backend.
+
+            1. Models with MHA Architecture (e.g: Llama, QWen)
+                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
+                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
+            2. Models with MLA Architecture and using FA3
+                2.1 We will use FA3 backend on hopper.
+                2.2 We will use Flashinfer backend on blackwell.
+                2.3 Otherwise, we will use triton backend.
+            """
+            if not use_mla_backend:
+                # MHA architecture
+                if (
+                    is_hopper_with_cuda_12_3()
+                    and is_no_spec_infer_or_topk_one(self)
+                    and is_fa3_default_architecture(self.model_config.hf_config)
+                ):
+                    self.attention_backend = "fa3"
+                elif is_hip():
+                    self.attention_backend = "aiter"
+                elif is_npu():
+                    self.attention_backend = "ascend"
+                else:
+                    self.attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+            else:
+                # MLA architecture
+                if is_hopper_with_cuda_12_3():
+                    self.attention_backend = "fa3"
+                elif is_sm100_supported():
+                    self.attention_backend = "flashinfer"
+                elif is_hip():
+                    head_num = model_config.get_num_kv_heads(self.tp_size)
+                    # TODO current aiter only support head number 16 or 128 head number
+                    if head_num == 128 or head_num == 16:
+                        self.attention_backend = "aiter"
+                    else:
+                        self.attention_backend = "triton"
+                elif is_npu():
+                    self.attention_backend = "ascend"
+                else:
+                    self.attention_backend = "triton"
+            logger.warning(
+                f"Attention backend not explicitly specified. Use {self.attention_backend} backend by default."
+            )
+
+        # Torch native and flex attention backends
         if self.attention_backend == "torch_native":
             logger.warning(
                 "Cuda graph is disabled because of using torch native attention backend"
@@ -1057,12 +1121,7 @@ class ServerArgs:
                 self.speculative_algorithm is None
             ), "Speculative decoding is currently not supported with Flex Attention backend"

-        if is_npu() and self.attention_backend in ["ascend"]:
-            logger.warning(
-                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
-            )
-            self.page_size = 128
-
+        # Major NVIDIA platforms backends
         if (
             self.attention_backend == "flashmla"
             or self.decode_attention_backend == "flashmla"
@@ -1117,19 +1176,13 @@ class ServerArgs:
             )
             self.page_size = 64

-        if self.attention_backend == "dual_chunk_flash_attn":
+        if self.attention_backend == "fa3" and self.kv_cache_dtype == "fp8_e5m2":
             logger.warning(
-                "Mixed chunk and radix cache are disabled when using dual-chunk flash attention backend"
+                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+                "Setting attention backend to triton."
             )
-            self.enable_mixed_chunk = False
-            self.disable_radix_cache = True
-
-        if self.attention_backend == "intel_xpu":
-            if self.page_size not in [32, 64, 128]:
-                logger.warning(
-                    f"Intel XPU attention backend only supports page_size of 32, 64 or 128, changing page_size from {self.page_size} to 128."
-                )
-                self.page_size = 128
+            self.attention_backend = "triton"

         if self.attention_backend == "fa4" or self.decode_attention_backend == "fa4":
             raise ValueError(
                 "FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead."
@@ -1140,6 +1193,66 @@ class ServerArgs:
             )
             self.page_size = 128

+        # AMD platforms backends
+        if self.attention_backend == "aiter":
+            if model_config.context_len > 8192:
+                self.mem_fraction_static *= 0.90
+
+        # NPU platforms backends
+        if is_npu() and self.attention_backend in ["ascend"]:
+            logger.warning(
+                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
+        # Other platforms backends
+        if (
+            self.attention_backend == "intel_amx"
+            and self.device == "cpu"
+            and not cpu_has_amx_support()
+        ):
+            logger.warning(
+                "The current platform does not support Intel AMX, will fallback to torch_native backend."
+            )
+            self.attention_backend = "torch_native"
+
+        if (
+            self.attention_backend == "intel_xpu"
+            and self.device == "xpu"
+            and not xpu_has_xmx_support()
+        ):
+            logger.warning(
+                "The current platform does not support Intel XMX, will fallback to triton backend."
+            )
+            self.attention_backend = "triton"
+
+        if self.attention_backend == "intel_xpu":
+            if self.page_size not in [32, 64, 128]:
+                logger.warning(
+                    f"Intel XPU attention backend only supports page_size of 32, 64 or 128, changing page_size from {self.page_size} to 128."
+                )
+                self.page_size = 128
+
+        # Dual chunk flash attention backend
+        if (
+            getattr(model_config.hf_config, "dual_chunk_attention_config", None)
+            is not None
+        ):
+            if self.attention_backend is None:
+                self.attention_backend = "dual_chunk_flash_attn"
+                logger.info("Dual chunk attention is turned on by default.")
+            elif self.attention_backend != "dual_chunk_flash_attn":
+                raise ValueError(
+                    "Dual chunk attention is enabled, but attention backend is set to "
+                    f"{self.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
+                )
+
+        if self.attention_backend == "dual_chunk_flash_attn":
+            logger.warning(
+                "Mixed chunk and radix cache are disabled when using dual-chunk flash attention backend"
+            )
+            self.enable_mixed_chunk = False
+            self.disable_radix_cache = True
+
     def _handle_page_size(self):
         if self.page_size is None:
             self.page_size = 1
@@ -1283,6 +1396,24 @@ class ServerArgs:
                 "Page first direct layout only support direct io backend"
             )

+        if self.enable_hierarchical_cache and self.hicache_io_backend == "kernel":
+            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+            if self.decode_attention_backend is None:
+                if not self.use_mla_backend():
+                    self.decode_attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+                else:
+                    self.decode_attention_backend = (
+                        "flashinfer" if is_sm100_supported() else "triton"
+                    )
+            elif self.decode_attention_backend == "fa3":
+                self.hicache_io_backend = "direct"
+                logger.warning(
+                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
     def _handle_speculative_decoding(self):
         if self.speculative_algorithm == "NEXTN":
             self.speculative_algorithm = "EAGLE"
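The hierarchical-cache block moved into `_handle_hicache` above encodes a small compatibility rule: when HiCache uses the kernel I/O backend, an unset decode backend is filled in with a HiCache-friendly default, while an explicitly chosen FA3 decode backend forces HiCache back to direct I/O. A hedged sketch of that rule as a standalone function (names mirror the server args, but this is not the real method):

```python
from typing import Optional, Tuple


def resolve_hicache_decode_backend(
    decode_attention_backend: Optional[str],
    hicache_io_backend: str,
    use_mla: bool,
    flashinfer_available: bool,
    sm100_supported: bool,
) -> Tuple[Optional[str], str]:
    """Return (decode_backend, hicache_io_backend) after the compatibility fix."""
    if hicache_io_backend != "kernel":
        return decode_attention_backend, hicache_io_backend
    if decode_attention_backend is None:
        # Pick a decode backend that works with the HiCache kernel backend.
        if not use_mla:
            decode_attention_backend = "flashinfer" if flashinfer_available else "triton"
        else:
            decode_attention_backend = "flashinfer" if sm100_supported else "triton"
    elif decode_attention_backend == "fa3":
        # FA3 decode is incompatible with the kernel I/O path; fall back to direct I/O.
        hicache_io_backend = "direct"
    return decode_attention_backend, hicache_io_backend


print(resolve_hicache_decode_backend(None, "kernel", use_mla=False,
                                     flashinfer_available=True, sm100_supported=False))
# -> ('flashinfer', 'kernel')
print(resolve_hicache_decode_backend("fa3", "kernel", use_mla=True,
                                     flashinfer_available=True, sm100_supported=True))
# -> ('fa3', 'direct')
```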
@@ -3355,19 +3486,34 @@ class ServerArgs:
             )
         return hf_config

-    def get_attention_backends(server_args):
+    def get_model_config(self):
+        # Lazy init to avoid circular import
+        from sglang.srt.configs.model_config import ModelConfig
+
+        if hasattr(self, "model_config"):
+            return self.model_config
+        self.model_config = ModelConfig.from_server_args(self)
+        return self.model_config
+
+    def get_attention_backends(self):
         prefill_attention_backend_str = (
-            server_args.prefill_attention_backend
-            if server_args.prefill_attention_backend
-            else server_args.attention_backend
+            self.prefill_attention_backend
+            if self.prefill_attention_backend
+            else self.attention_backend
         )
         decode_attention_backend_str = (
-            server_args.decode_attention_backend
-            if server_args.decode_attention_backend
-            else server_args.attention_backend
+            self.decode_attention_backend
+            if self.decode_attention_backend
+            else self.attention_backend
         )
         return prefill_attention_backend_str, decode_attention_backend_str

+    def use_mla_backend(self):
+        from sglang.srt.configs.model_config import AttentionArch
+
+        model_config = self.get_model_config()
+        return model_config.attention_arch == AttentionArch.MLA
+
     def check_server_args(self):
         # Check parallel size constraints
         assert (
...
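The new `get_model_config` above is a simple memoization: the first call imports `ModelConfig` lazily (to avoid a circular import) and caches the result on the instance, and `use_mla_backend` then reads the cached config. This cached attribute is also why the HTTP endpoint and scheduler hunks earlier in the diff drop `model_config` before serializing the args. A minimal sketch of the same cache-on-first-use pattern with stand-in classes (not the real SGLang types):

```python
class ModelConfigStub:
    """Stand-in for sglang.srt.configs.model_config.ModelConfig."""

    def __init__(self, attention_arch: str) -> None:
        self.attention_arch = attention_arch


class ArgsStub:
    def __init__(self, arch: str) -> None:
        self._arch = arch

    def get_model_config(self) -> ModelConfigStub:
        # Build the config once and cache it on the instance; later calls reuse it.
        if hasattr(self, "model_config"):
            return self.model_config
        self.model_config = ModelConfigStub(self._arch)
        return self.model_config

    def use_mla_backend(self) -> bool:
        return self.get_model_config().attention_arch == "MLA"


args = ArgsStub("MLA")
assert args.get_model_config() is args.get_model_config()  # cached, built only once
print(args.use_mla_backend())  # True
```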
@@ -2096,80 +2096,80 @@ class MultiprocessingSerializer:
             # Decode base64 string to bytes
             data = pybase64.b64decode(data, validate=True)

-        class SafeUnpickler(pickle.Unpickler):
-            ALLOWED_MODULE_PREFIXES = {
-                # --- Python types ---
-                "builtins.",
-                "collections.",
-                "copyreg.",
-                "functools.",
-                "itertools.",
-                "operator.",
-                "types.",
-                "weakref.",
-                # --- PyTorch types ---
-                "torch.",
-                "torch._tensor.",
-                "torch.storage.",
-                "torch.nn.parameter.",
-                "torch.autograd.function.",
-                # --- torch distributed ---
-                "torch.distributed.",
-                "torch.distributed._shard.",
-                "torch.distributed._composable.",
-                "torch._C._distributed_c10d.",
-                "torch._C._distributed_fsdp.",
-                "torch.distributed.optim.",
-                # --- multiprocessing ---
-                "multiprocessing.resource_sharer.",
-                "multiprocessing.reduction.",
-                "pickletools.",
-                # --- PEFT / LoRA ---
-                "peft.",
-                "transformers.",
-                "huggingface_hub.",
-                # --- SGLang & Unitest ---
-                "sglang.srt.weight_sync.tensor_bucket.",
-                "sglang.srt.model_executor.model_runner.",
-                "sglang.srt.layers.",
-                "sglang.srt.utils.",
-            }
-
-            DENY_CLASSES = {
-                ("builtins", "eval"),
-                ("builtins", "exec"),
-                ("builtins", "compile"),
-                ("os", "system"),
-                ("subprocess", "Popen"),
-                ("subprocess", "run"),
-                ("codecs", "decode"),
-                ("types", "CodeType"),
-                ("types", "FunctionType"),
-            }
-
-            def find_class(self, module, name):
-                # Block deterministic attacks
-                if (module, name) in self.DENY_CLASSES:
-                    raise RuntimeError(
-                        f"Blocked unsafe class loading ({module}.{name}), "
-                        f"to prevent exploitation of CVE-2025-10164"
-                    )
-
-                # Allowlist of safe-to-load modules.
-                if any(
-                    (module + ".").startswith(prefix)
-                    for prefix in self.ALLOWED_MODULE_PREFIXES
-                ):
-                    return super().find_class(module, name)
-
-                # Block everything else. (Potential attack surface)
-                raise RuntimeError(
-                    f"Blocked unsafe class loading ({module}.{name}), "
-                    f"to prevent exploitation of CVE-2025-10164"
-                )
-
         return SafeUnpickler(io.BytesIO(data)).load()

+
+class SafeUnpickler(pickle.Unpickler):
+    ALLOWED_MODULE_PREFIXES = {
+        # --- Python types ---
+        "builtins.",
+        "collections.",
+        "copyreg.",
+        "functools.",
+        "itertools.",
+        "operator.",
+        "types.",
+        "weakref.",
+        # --- PyTorch types ---
+        "torch.",
+        "torch._tensor.",
+        "torch.storage.",
+        "torch.nn.parameter.",
+        "torch.autograd.function.",
+        # --- torch distributed ---
+        "torch.distributed.",
+        "torch.distributed._shard.",
+        "torch.distributed._composable.",
+        "torch._C._distributed_c10d.",
+        "torch._C._distributed_fsdp.",
+        "torch.distributed.optim.",
+        # --- multiprocessing ---
+        "multiprocessing.resource_sharer.",
+        "multiprocessing.reduction.",
+        "pickletools.",
+        # --- PEFT / LoRA ---
+        "peft.",
+        "transformers.",
+        "huggingface_hub.",
+        # --- SGLang & Unitest ---
+        "sglang.srt.weight_sync.tensor_bucket.",
+        "sglang.srt.model_executor.model_runner.",
+        "sglang.srt.layers.",
+        "sglang.srt.utils.",
+    }
+
+    DENY_CLASSES = {
+        ("builtins", "eval"),
+        ("builtins", "exec"),
+        ("builtins", "compile"),
+        ("os", "system"),
+        ("subprocess", "Popen"),
+        ("subprocess", "run"),
+        ("codecs", "decode"),
+        ("types", "CodeType"),
+        ("types", "FunctionType"),
+    }
+
+    def find_class(self, module, name):
+        # Block deterministic attacks
+        if (module, name) in self.DENY_CLASSES:
+            raise RuntimeError(
+                f"Blocked unsafe class loading ({module}.{name}), "
+                f"to prevent exploitation of CVE-2025-10164"
+            )
+
+        # Allowlist of safe-to-load modules.
+        if any(
+            (module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES
+        ):
+            return super().find_class(module, name)
+
+        # Block everything else. (Potential attack surface)
+        raise RuntimeError(
+            f"Blocked unsafe class loading ({module}.{name}), "
+            f"to prevent exploitation of CVE-2025-10164"
+        )
+
+
 def debug_timing(func):
     # todo: replace with a more organized instrumentation
     def wrapper(*args, **kwargs):
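The hunk above only moves `SafeUnpickler` from inside the deserialize method to module level; the allowlist and denylist behavior is unchanged. For context, here is a small hedged demonstration of how an allowlist-based `pickle.Unpickler` rejects classes outside the permitted modules. This is a standalone toy, not the SGLang class:

```python
import datetime
import io
import pickle


class ToyAllowlistUnpickler(pickle.Unpickler):
    """Toy version of the allowlist idea: only builtins/collections may be loaded."""

    ALLOWED_MODULE_PREFIXES = ("builtins.", "collections.")

    def find_class(self, module, name):
        if any((module + ".").startswith(p) for p in self.ALLOWED_MODULE_PREFIXES):
            return super().find_class(module, name)
        raise RuntimeError(f"Blocked unsafe class loading ({module}.{name})")


# Plain data round-trips fine; no class lookup is needed for dicts, lists, or ints.
payload = pickle.dumps({"ok": [1, 2, 3]})
print(ToyAllowlistUnpickler(io.BytesIO(payload)).load())

# A payload referencing a class outside the allowlist is rejected at load time.
bad_payload = pickle.dumps(datetime.date(2024, 1, 1))
try:
    ToyAllowlistUnpickler(io.BytesIO(bad_payload)).load()
except RuntimeError as e:
    print(e)  # e.g. "Blocked unsafe class loading (datetime.date)"
```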
@@ -2620,17 +2620,12 @@ def get_local_ip_auto(fallback: str = None) -> str:
     raise ValueError("Can not get local ip")


-def is_page_size_one(server_args):
-    return server_args.page_size == 1
-
-
 # TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
 # TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
 def is_no_spec_infer_or_topk_one(server_args):
     return server_args.speculative_eagle_topk is None or (
-        server_args.speculative_eagle_topk is not None
-        and server_args.speculative_eagle_topk == 1
-        and is_page_size_one(server_args)
+        server_args.speculative_eagle_topk == 1
+        and (server_args.page_size == 1 or server_args.page_size is None)
     )
...
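Besides dropping the redundant `is not None` check and the `is_page_size_one` helper, the simplification above also changes behavior slightly: the old form required `page_size == 1`, while the new form also accepts `page_size is None` (a page size that has not been resolved yet). A small hedged check with a stand-in args object illustrates the difference:

```python
from types import SimpleNamespace


def is_no_spec_infer_or_topk_one(server_args):
    # New form from the diff.
    return server_args.speculative_eagle_topk is None or (
        server_args.speculative_eagle_topk == 1
        and (server_args.page_size == 1 or server_args.page_size is None)
    )


def old_is_no_spec_infer_or_topk_one(server_args):
    # Old form, with the is_page_size_one helper inlined.
    return server_args.speculative_eagle_topk is None or (
        server_args.speculative_eagle_topk is not None
        and server_args.speculative_eagle_topk == 1
        and server_args.page_size == 1
    )


cases = [
    SimpleNamespace(speculative_eagle_topk=None, page_size=16),  # no spec decode
    SimpleNamespace(speculative_eagle_topk=1, page_size=1),      # topk 1, page 1
    SimpleNamespace(speculative_eagle_topk=1, page_size=None),   # only new form is True
    SimpleNamespace(speculative_eagle_topk=4, page_size=1),      # topk > 1
]
for args in cases:
    print(old_is_no_spec_infer_or_topk_one(args), is_no_spec_infer_or_topk_one(args))
# -> True True / True True / False True / False False
```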