"vscode:/vscode.git/clone" did not exist on "aa17d78af77616891b42da8b554a050a034d32c4"
Unverified commit ea338676, authored by Lianmin Zheng, committed by GitHub

Clean up server args (#10770)

parent b06db198
@@ -19,8 +19,6 @@ import json
 import logging
 import os
 import random
-import socket
-import sys
 import tempfile
 from typing import List, Literal, Optional, Union
@@ -328,6 +326,10 @@ class ServerArgs:
     deepep_config: Optional[str] = None
     moe_dense_tp_size: Optional[int] = None
 
+    # Mamba cache
+    max_mamba_cache_size: Optional[int] = None
+    mamba_ssm_dtype: str = "float32"
+
     # Hierarchical cache
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
@@ -398,6 +400,7 @@ class ServerArgs:
     enable_return_hidden_states: bool = False
     scheduler_recv_interval: int = 1
     numa_node: Optional[List[int]] = None
+    enable_deterministic_inference: bool = False
 
     # Dynamic batch tokenizer
     enable_dynamic_batch_tokenizer: bool = False
@@ -419,15 +422,12 @@ class ServerArgs:
     disaggregation_prefill_pp: Optional[int] = 1
     disaggregation_ib_device: Optional[str] = None
     num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
     # FIXME: hack to reduce ITL when decode bs is small
     disaggregation_decode_polling_interval: int = 1
 
-    # For model weight update
+    # For model weight update and weight loading
     custom_weight_loader: Optional[List[str]] = None
     weight_loader_disable_mmap: bool = False
-
-    # Remote instance weight loading
     remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
     remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
     remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
@@ -436,58 +436,84 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
-    # Mamba cache
-    max_mamba_cache_size: Optional[int] = None
-    mamba_ssm_dtype: str = "float32"
-
-    # For deterministic inference
-    enable_deterministic_inference: bool = False
-
-    # Deprecated arguments
-    enable_ep_moe: bool = False
-    enable_deepep_moe: bool = False
-    enable_flashinfer_cutlass_moe: bool = False
-    enable_flashinfer_cutedsl_moe: bool = False
-    enable_flashinfer_trtllm_moe: bool = False
-    enable_triton_kernel_moe: bool = False
-    enable_flashinfer_mxfp4_moe: bool = False
+    def __post_init__(self):
+        """
+        Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
+        """
+        # Handle deprecated arguments.
+        self._handle_deprecated_args()
+
+        # Set missing default values.
+        self._handle_missing_default_values()
+
+        # Get GPU memory capacity, which is a common dependency for several configuration steps.
+        gpu_mem = get_device_memory_capacity(self.device)
+
+        # Handle memory-related configurations.
+        self._handle_mem_fraction_static(gpu_mem)
+        self._handle_chunked_prefill_size(gpu_mem)
+
+        # Handle CUDA graph settings.
+        self._handle_cuda_graph_max_bs(gpu_mem)
+
+        # Handle device-specific backends.
+        self._handle_hpu_backends()
+        self._handle_cpu_backends()
+
+        # Apply model-specific adjustments.
+        self._handle_model_specific_adjustments()
+
+        # Set kernel backends.
+        self._handle_sampling_backend()
+        self._handle_attention_backend_compatibility()
+        self._handle_page_size()
+        self._handle_amd_specifics()
+        self._handle_grammar_backend()
+
+        # Handle data parallelism.
+        self._handle_data_parallelism()
+
+        # Handle MoE configurations.
+        self._handle_moe_kernel_config()
+        self._handle_deepep_moe()
+        self._handle_eplb_and_dispatch()
+        self._handle_expert_distribution_metrics()
+
+        # Handle pipeline parallelism.
+        self._handle_pipeline_parallelism()
+
+        # Handle Hicache settings.
+        self._handle_hicache()
+
+        # Handle speculative decoding logic.
+        self._handle_speculative_decoding()
+
+        # Handle model loading format.
+        self._handle_load_format()
+
+        # Handle PD disaggregation.
+        self._handle_disaggregation()
+
+        # Validate tokenizer settings.
+        self._handle_tokenizer_batching()
+
+        # Propagate environment variables.
+        self._handle_environment_variables()
+
+        # Validate cache settings.
+        self._handle_cache_compatibility()
+
+        # Validate metrics labels.
+        self._handle_metrics_labels()
+
+        # Handle deterministic inference.
+        self._handle_deterministic_inference()
+
+        # Handle any other necessary validations.
+        self._handle_other_validations()
 
     def _handle_deprecated_args(self):
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            print_deprecated_warning(
-                "NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
-            )
-        if self.enable_deepep_moe:
-            self.moe_a2a_backend = "deepep"
-            print_deprecated_warning(
-                "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
-            )
-        if self.enable_triton_kernel_moe:
-            self.moe_runner_backend = "triton_kernel"
-            print_deprecated_warning(
-                "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
-            )
-        if self.enable_flashinfer_cutedsl_moe:
-            self.moe_runner_backend = "flashinfer_cutedsl"
-            print_deprecated_warning(
-                "NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
-            )
-        if self.enable_flashinfer_cutlass_moe:
-            self.moe_runner_backend = "flashinfer_cutlass"
-            print_deprecated_warning(
-                "NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
-            )
-        if self.enable_flashinfer_trtllm_moe:
-            self.moe_runner_backend = "flashinfer_trtllm"
-            print_deprecated_warning(
-                "NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
-            )
-        if self.enable_flashinfer_mxfp4_moe:
-            self.moe_runner_backend = "flashinfer_mxfp4"
-            print_deprecated_warning(
-                "NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead."
-            )
+        pass
 
     def _handle_missing_default_values(self):
         if self.tokenizer_path is None:
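Since `ServerArgs` is a dataclass, the orchestration above runs on every construction: the generated `__init__` assigns the fields and then calls `__post_init__`, which walks the handler chain. A minimal sketch of that mechanism (hypothetical, trimmed-down fields; one handler stands in for the full chain):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TinyServerArgs:
    # Hypothetical stand-in for ServerArgs, not the real field set.
    model_path: str
    attention_backend: Optional[str] = None

    def __post_init__(self):
        # Runs automatically right after the generated __init__.
        self._handle_missing_default_values()

    def _handle_missing_default_values(self):
        if self.attention_backend is None:
            self.attention_backend = "triton"


args = TinyServerArgs(model_path="some/model")
assert args.attention_backend == "triton"
```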
@@ -590,6 +616,84 @@ class ServerArgs:
             self.attention_backend = "intel_amx"
             self.sampling_backend = "pytorch"
 
+    def _handle_model_specific_adjustments(self):
+        if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
+            return
+
+        hf_config = self.get_hf_config()
+        model_arch = hf_config.architectures[0]
+        if model_arch in ["GptOssForCausalLM"]:
+            if self.attention_backend is None:
+                if is_cuda() and is_sm100_supported():
+                    self.attention_backend = "trtllm_mha"
+                elif is_cuda() and is_sm90_supported():
+                    self.attention_backend = "fa3"
+                else:
+                    self.attention_backend = "triton"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            logger.info(
+                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
+            )
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+
+            if is_sm100_supported():
+                if not self.enable_dp_attention:
+                    self.enable_flashinfer_allreduce_fusion = True
+                    logger.info(
+                        "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
+                    )
+            quantization_config = getattr(hf_config, "quantization_config", None)
+            is_mxfp4_quant_format = (
+                quantization_config is not None
+                and quantization_config.get("quant_method") == "mxfp4"
+            )
+
+            if is_sm100_supported() and is_mxfp4_quant_format:
+                self.moe_runner_backend = "flashinfer_mxfp4"
+                logger.warning(
+                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+                )
+            else:
+                if self.moe_runner_backend == "triton_kernel":
+                    assert (
+                        self.ep_size == 1
+                    ), "Triton kernel MoE is only supported when ep_size == 1"
+                if (
+                    self.moe_runner_backend == "auto"
+                    and self.ep_size == 1
+                    and is_triton_kernels_available()
+                ):
+                    self.moe_runner_backend = "triton_kernel"
+                    logger.warning(
+                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+                    )
+            self.disable_hybrid_swa_memory = True
+            if is_mxfp4_quant_format:
+                # use bf16 for mxfp4 triton kernels
+                self.dtype = "bfloat16"
+        elif "Llama4" in model_arch and self.device != "cpu":
+            assert self.attention_backend in {
+                "fa3",
+                "aiter",
+                "triton",
+            }, "fa3, aiter, or triton is required for Llama4 model"
+        elif model_arch in [
+            "Gemma2ForCausalLM",
+            "Gemma3ForCausalLM",
+            "Gemma3ForConditionalGeneration",
+            "Gemma3nForCausalLM",
+            "Gemma3nForConditionalGeneration",
+        ]:
+            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
+            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
+            logger.warning(
+                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+            )
+            self.disable_hybrid_swa_memory = True
+
     def _handle_sampling_backend(self):
         if self.sampling_backend is None:
             self.sampling_backend = (
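The GPT-OSS branch above keys its choices on helpers such as `is_sm100_supported()` and `is_sm90_supported()`, whose implementations are not part of this diff. A hedged sketch of how such a capability probe can be written with PyTorch (the function name and threshold are assumptions, not the project's actual helpers):

```python
import torch


def is_sm100_supported_sketch() -> bool:
    # Assumption: Blackwell-class (SM100) GPUs report CUDA compute capability major == 10.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10
```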
@@ -1014,83 +1118,6 @@ class ServerArgs:
     def _handle_other_validations(self):
         pass
 
-    def __post_init__(self):
-        """
-        Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
-        """
-        # Step 1: Handle deprecated arguments.
-        self._handle_deprecated_args()
-
-        # Step 2: Set missing default values.
-        self._handle_missing_default_values()
-
-        # Get GPU memory capacity, which is a common dependency for several configuration steps.
-        gpu_mem = get_device_memory_capacity(self.device)
-
-        # Step 3: Handle memory-related configurations.
-        self._handle_mem_fraction_static(gpu_mem)
-        self._handle_chunked_prefill_size(gpu_mem)
-
-        # Step 4: Handle CUDA graph settings.
-        self._handle_cuda_graph_max_bs(gpu_mem)
-
-        # Step 5: Handle device-specific backends.
-        self._handle_hpu_backends()
-        self._handle_cpu_backends()
-
-        # Step 6: Apply model-specific adjustments.
-        if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
-            self.model_specific_adjustments()
-
-        # Step 7: Set kernel backends.
-        self._handle_sampling_backend()
-        self._handle_attention_backend_compatibility()
-        self._handle_page_size()
-        self._handle_amd_specifics()
-        self._handle_grammar_backend()
-
-        # Step 8: Handle data parallelism.
-        self._handle_data_parallelism()
-
-        # Step 9: Handle MoE configurations.
-        self._handle_moe_kernel_config()
-        self._handle_deepep_moe()
-        self._handle_eplb_and_dispatch()
-        self._handle_expert_distribution_metrics()
-
-        # Step 10: Handle pipeline parallelism.
-        self._handle_pipeline_parallelism()
-
-        # Step 11: Handle Hicache settings.
-        self._handle_hicache()
-
-        # Step 12: Handle speculative decoding logic.
-        self._handle_speculative_decoding()
-
-        # Step 13: Handle model loading format.
-        self._handle_load_format()
-
-        # Step 14: Handle PD disaggregation.
-        self._handle_disaggregation()
-
-        # Step 15: Validate tokenizer settings.
-        self._handle_tokenizer_batching()
-
-        # Step 16: Propagate environment variables.
-        self._handle_environment_variables()
-
-        # Step 17: Validate cache settings.
-        self._handle_cache_compatibility()
-
-        # Step 18: Validate metrics labels.
-        self._handle_metrics_labels()
-
-        # Step 19: Handle deterministic inference.
-        self._handle_deterministic_inference()
-
-        # Step 20: Handle any other necessary validations.
-        self._handle_other_validations()
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
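The `add_cli_args` staticmethod whose signature appears above mirrors the dataclass: each field becomes a CLI flag whose default is read from the class attribute (visible in the `default=ServerArgs....` lines below). A common companion, sketched here as an assumption since its body is not shown in this diff, is a classmethod that rebuilds the dataclass from the parsed namespace:

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class SketchArgs:
    # Hypothetical two-field stand-in for ServerArgs.
    model_path: str = ""
    page_size: int = 1

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Defaults come from the class attributes, mirroring the real add_cli_args.
        parser.add_argument("--model-path", type=str, default=SketchArgs.model_path)
        parser.add_argument("--page-size", type=int, default=SketchArgs.page_size)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Keep only the namespace attributes that correspond to dataclass fields.
        return cls(**{f.name: getattr(args, f.name) for f in fields(cls)})


parser = argparse.ArgumentParser()
SketchArgs.add_cli_args(parser)
ns = parser.parse_args(["--model-path", "some/model", "--page-size", "16"])
print(SketchArgs.from_cli_args(ns))
```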
@@ -1101,24 +1128,6 @@ class ServerArgs:
             help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
             required=True,
         )
-        parser.add_argument(
-            "--remote-instance-weight-loader-seed-instance-ip",
-            type=str,
-            default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
-            help="The ip of the seed instance for loading weights from remote instance.",
-        )
-        parser.add_argument(
-            "--remote-instance-weight-loader-seed-instance-service-port",
-            type=int,
-            default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
-            help="The service port of the seed instance for loading weights from remote instance.",
-        )
-        parser.add_argument(
-            "--remote-instance-weight-loader-send-weights-group-ports",
-            type=json_list_type,
-            default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
-            help="The communication group ports for loading weights from remote instance.",
-        )
         parser.add_argument(
             "--tokenizer-path",
             type=str,
@@ -2573,6 +2582,24 @@ class ServerArgs:
             action="store_true",
             help="Disable mmap while loading weight using safetensors.",
         )
+        parser.add_argument(
+            "--remote-instance-weight-loader-seed-instance-ip",
+            type=str,
+            default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
+            help="The ip of the seed instance for loading weights from remote instance.",
+        )
+        parser.add_argument(
+            "--remote-instance-weight-loader-seed-instance-service-port",
+            type=int,
+            default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
+            help="The service port of the seed instance for loading weights from remote instance.",
+        )
+        parser.add_argument(
+            "--remote-instance-weight-loader-send-weights-group-ports",
+            type=json_list_type,
+            default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
+            help="The communication group ports for loading weights from remote instance.",
+        )
 
         # For PD-Multiplexing
         parser.add_argument(
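`--remote-instance-weight-loader-send-weights-group-ports` is parsed with `json_list_type`, which is referenced but not defined in this diff. A hypothetical sketch of such a converter (the name is reused for illustration only; the real one may differ):

```python
import argparse
import json
from typing import List


def json_list_type(value: str) -> List[int]:
    # Hypothetical converter: accepts a JSON array string from the CLI,
    # e.g. --remote-instance-weight-loader-send-weights-group-ports '[23456, 23457]'.
    parsed = json.loads(value)
    if not isinstance(parsed, list):
        raise argparse.ArgumentTypeError(f"expected a JSON list, got: {value!r}")
    return parsed
```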
@@ -2598,38 +2625,38 @@ class ServerArgs:
         # Deprecated arguments
         parser.add_argument(
             "--enable-ep-moe",
-            action="store_true",
-            help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+            action=DeprecatedAction,
+            help="NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead.",
         )
         parser.add_argument(
             "--enable-deepep-moe",
-            action="store_true",
-            help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
+            action=DeprecatedAction,
+            help="NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead.",
         )
         parser.add_argument(
             "--enable-flashinfer-cutlass-moe",
-            action="store_true",
-            help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
+            action=DeprecatedAction,
+            help="NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead.",
         )
         parser.add_argument(
             "--enable-flashinfer-cutedsl-moe",
-            action="store_true",
-            help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
+            action=DeprecatedAction,
+            help="NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead.",
        )
         parser.add_argument(
             "--enable-flashinfer-trtllm-moe",
-            action="store_true",
-            help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
+            action=DeprecatedAction,
+            help="NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead.",
         )
         parser.add_argument(
             "--enable-triton-kernel-moe",
-            action="store_true",
-            help="(Deprecated) Use triton moe grouped gemm kernel.",
+            action=DeprecatedAction,
+            help="NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead.",
         )
         parser.add_argument(
             "--enable-flashinfer-mxfp4-moe",
-            action="store_true",
-            help="(Deprecated) Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
+            action=DeprecatedAction,
+            help="NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead.",
         )
 
     @classmethod
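The deprecated flags now use `action=DeprecatedAction`, whose definition is not shown in this diff. A minimal sketch of what such an argparse action can look like, assuming the intended behavior is to fail fast and surface the migration hint carried in the help text:

```python
import argparse


class DeprecatedAction(argparse.Action):
    """Hypothetical sketch: reject a deprecated flag at parse time."""

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # The help string already names the replacement flag, so reuse it as the error.
        parser.error(self.help)
```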
@@ -2862,81 +2889,6 @@ class ServerArgs:
                 val >= 0 for val in bucket_values
             ), f"{arg_name} customer rule bucket values should be non-negative"
 
-    def model_specific_adjustments(self):
-        hf_config = self.get_hf_config()
-        model_arch = hf_config.architectures[0]
-        if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
-                if is_cuda() and is_sm100_supported():
-                    self.attention_backend = "trtllm_mha"
-                elif is_cuda() and is_sm90_supported():
-                    self.attention_backend = "fa3"
-                else:
-                    self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
-            logger.info(
-                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
-            )
-            assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
-
-            if is_sm100_supported():
-                if not self.enable_dp_attention:
-                    self.enable_flashinfer_allreduce_fusion = True
-                    logger.info(
-                        "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
-                    )
-            quantization_config = getattr(hf_config, "quantization_config", None)
-            is_mxfp4_quant_format = (
-                quantization_config is not None
-                and quantization_config.get("quant_method") == "mxfp4"
-            )
-
-            if is_sm100_supported() and is_mxfp4_quant_format:
-                self.moe_runner_backend = "flashinfer_mxfp4"
-                logger.warning(
-                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
-                )
-            else:
-                if self.moe_runner_backend == "triton_kernel":
-                    assert (
-                        self.ep_size == 1
-                    ), "Triton kernel MoE is only supported when ep_size == 1"
-                if (
-                    self.moe_runner_backend == "auto"
-                    and self.ep_size == 1
-                    and is_triton_kernels_available()
-                ):
-                    self.moe_runner_backend = "triton_kernel"
-                    logger.warning(
-                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
-                    )
-            self.disable_hybrid_swa_memory = True
-            if is_mxfp4_quant_format:
-                # use bf16 for mxfp4 triton kernels
-                self.dtype = "bfloat16"
-        elif "Llama4" in model_arch and self.device != "cpu":
-            assert self.attention_backend in {
-                "fa3",
-                "aiter",
-                "triton",
-            }, "fa3, aiter, or triton is required for Llama4 model"
-        elif model_arch in [
-            "Gemma2ForCausalLM",
-            "Gemma3ForCausalLM",
-            "Gemma3ForConditionalGeneration",
-            "Gemma3nForCausalLM",
-            "Gemma3nForConditionalGeneration",
-        ]:
-            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
-            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
-            logger.warning(
-                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
-            )
-            self.disable_hybrid_swa_memory = True
-
     def adjust_mem_fraction_for_vlm(self, model_config):
         vision_config = getattr(model_config.hf_config, "vision_config", None)
         if vision_config is None:
...