Unverified Commit ea338676 authored by Lianmin Zheng, committed by GitHub

Clean up server args (#10770)

parent b06db198
@@ -19,8 +19,6 @@ import json
import logging
import os
import random
import socket
import sys
import tempfile
from typing import List, Literal, Optional, Union
@@ -328,6 +326,10 @@ class ServerArgs:
deepep_config: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
# Mamba cache
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: str = "float32"
# Hierarchical cache
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
@@ -398,6 +400,7 @@ class ServerArgs:
enable_return_hidden_states: bool = False
scheduler_recv_interval: int = 1
numa_node: Optional[List[int]] = None
enable_deterministic_inference: bool = False
# Dynamic batch tokenizer
enable_dynamic_batch_tokenizer: bool = False
@@ -419,15 +422,12 @@ class ServerArgs:
disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
# FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1
# For model weight update
# For model weight update and weight loading
custom_weight_loader: Optional[List[str]] = None
weight_loader_disable_mmap: bool = False
# Remote instance weight loading
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
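Note: the three remote-instance weight-loader fields above keep their names; this commit only regroups them with the other weight-loading options (their CLI flags are re-added further down in the diff). A hedged sketch, not part of this commit, of how those flags land on the ServerArgs fields; the import path, model path, addresses, and the JSON form accepted by json_list_type are assumptions:

import argparse

from sglang.srt.server_args import ServerArgs  # assumed import path

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args([
    "--model-path", "dummy/model",  # placeholder
    "--remote-instance-weight-loader-seed-instance-ip", "10.0.0.1",
    "--remote-instance-weight-loader-seed-instance-service-port", "30001",
    "--remote-instance-weight-loader-send-weights-group-ports", "[31000, 31001]",
])
# Dashed flag names map onto the underscored dataclass fields above.
assert args.remote_instance_weight_loader_seed_instance_ip == "10.0.0.1"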
@@ -436,58 +436,84 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3
# Mamba cache
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: str = "float32"
def __post_init__(self):
"""
Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
"""
# Handle deprecated arguments.
self._handle_deprecated_args()
# For deterministic inference
enable_deterministic_inference: bool = False
# Set missing default values.
self._handle_missing_default_values()
# Get GPU memory capacity, which is a common dependency for several configuration steps.
gpu_mem = get_device_memory_capacity(self.device)
# Handle memory-related configurations.
self._handle_mem_fraction_static(gpu_mem)
self._handle_chunked_prefill_size(gpu_mem)
# Handle CUDA graph settings.
self._handle_cuda_graph_max_bs(gpu_mem)
# Handle device-specific backends.
self._handle_hpu_backends()
self._handle_cpu_backends()
# Apply model-specific adjustments.
self._handle_model_specific_adjustments()
# Set kernel backends.
self._handle_sampling_backend()
self._handle_attention_backend_compatibility()
self._handle_page_size()
self._handle_amd_specifics()
self._handle_grammar_backend()
# Handle data parallelism.
self._handle_data_parallelism()
# Handle MoE configurations.
self._handle_moe_kernel_config()
self._handle_deepep_moe()
self._handle_eplb_and_dispatch()
self._handle_expert_distribution_metrics()
# Handle pipeline parallelism.
self._handle_pipeline_parallelism()
# Handle Hicache settings.
self._handle_hicache()
# Handle speculative decoding logic.
self._handle_speculative_decoding()
# Handle model loading format.
self._handle_load_format()
# Handle PD disaggregation.
self._handle_disaggregation()
# Validate tokenizer settings.
self._handle_tokenizer_batching()
# Propagate environment variables.
self._handle_environment_variables()
# Validate cache settings.
self._handle_cache_compatibility()
# Validate metrics labels.
self._handle_metrics_labels()
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_cutedsl_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
# Handle deterministic inference.
self._handle_deterministic_inference()
# Handle any other necessary validations.
self._handle_other_validations()
def _handle_deprecated_args(self):
if self.enable_ep_moe:
self.ep_size = self.tp_size
print_deprecated_warning(
"NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
)
if self.enable_deepep_moe:
self.moe_a2a_backend = "deepep"
print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
)
if self.enable_triton_kernel_moe:
self.moe_runner_backend = "triton_kernel"
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutedsl_moe:
self.moe_runner_backend = "flashinfer_cutedsl"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
)
if self.enable_flashinfer_trtllm_moe:
self.moe_runner_backend = "flashinfer_trtllm"
print_deprecated_warning(
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
)
if self.enable_flashinfer_mxfp4_moe:
self.moe_runner_backend = "flashinfer_mxfp4"
print_deprecated_warning(
"NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead."
)
pass
def _handle_missing_default_values(self):
if self.tokenizer_path is None:
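For quick reference, a compact restatement, not part of the commit, of the flag-to-argument remapping that _handle_deprecated_args applies above (the table name is made up for illustration):

# Keys are the deprecated boolean fields; values are (new argument, value it is set to).
DEPRECATED_MOE_FLAG_MAP = {
    "enable_ep_moe":                 ("ep_size", "same value as tp_size"),
    "enable_deepep_moe":             ("moe_a2a_backend", "deepep"),
    "enable_triton_kernel_moe":      ("moe_runner_backend", "triton_kernel"),
    "enable_flashinfer_cutedsl_moe": ("moe_runner_backend", "flashinfer_cutedsl"),
    "enable_flashinfer_cutlass_moe": ("moe_runner_backend", "flashinfer_cutlass"),
    "enable_flashinfer_trtllm_moe":  ("moe_runner_backend", "flashinfer_trtllm"),
    "enable_flashinfer_mxfp4_moe":   ("moe_runner_backend", "flashinfer_mxfp4"),
}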
@@ -590,6 +616,84 @@ class ServerArgs:
self.attention_backend = "intel_amx"
self.sampling_backend = "pytorch"
def _handle_model_specific_adjustments(self):
if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
return
hf_config = self.get_hf_config()
model_arch = hf_config.architectures[0]
if model_arch in ["GptOssForCausalLM"]:
if self.attention_backend is None:
if is_cuda() and is_sm100_supported():
self.attention_backend = "trtllm_mha"
elif is_cuda() and is_sm90_supported():
self.attention_backend = "fa3"
else:
self.attention_backend = "triton"
supported_backends = ["triton", "trtllm_mha", "fa3"]
logger.info(
f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
)
assert (
self.attention_backend in supported_backends
), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
if is_sm100_supported():
if not self.enable_dp_attention:
self.enable_flashinfer_allreduce_fusion = True
logger.info(
"Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
)
quantization_config = getattr(hf_config, "quantization_config", None)
is_mxfp4_quant_format = (
quantization_config is not None
and quantization_config.get("quant_method") == "mxfp4"
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.moe_runner_backend = "flashinfer_mxfp4"
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.moe_runner_backend == "triton_kernel":
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if (
self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True
if is_mxfp4_quant_format:
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
elif "Llama4" in model_arch and self.device != "cpu":
assert self.attention_backend in {
"fa3",
"aiter",
"triton",
}, "fa3, aiter, or triton is required for Llama4 model"
elif model_arch in [
"Gemma2ForCausalLM",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForCausalLM",
"Gemma3nForConditionalGeneration",
]:
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
# It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
logger.warning(
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
)
self.disable_hybrid_swa_memory = True
def _handle_sampling_backend(self):
if self.sampling_backend is None:
self.sampling_backend = (
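As a reading aid, the GptOssForCausalLM branch above boils down to two defaults, sketched here as standalone helpers (not part of the commit; the function names are made up):

def pick_gptoss_attention_backend(cuda: bool, sm100: bool, sm90: bool) -> str:
    # Mirrors the default-selection order in _handle_model_specific_adjustments().
    if cuda and sm100:
        return "trtllm_mha"  # Blackwell-class GPUs
    if cuda and sm90:
        return "fa3"         # Hopper-class GPUs
    return "triton"          # fallback for everything else


def pick_gptoss_moe_runner_backend(
    sm100: bool, mxfp4_quant: bool, ep_size: int,
    triton_kernels_available: bool, current: str = "auto",
) -> str:
    # Mirrors the MoE runner selection above: MXFP4 on SM100 takes priority,
    # otherwise fall back to triton_kernel when ep_size == 1 and the kernels exist.
    if sm100 and mxfp4_quant:
        return "flashinfer_mxfp4"
    if current == "auto" and ep_size == 1 and triton_kernels_available:
        return "triton_kernel"
    return current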
@@ -1014,83 +1118,6 @@ class ServerArgs:
def _handle_other_validations(self):
pass
def __post_init__(self):
"""
Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
"""
# Step 1: Handle deprecated arguments.
self._handle_deprecated_args()
# Step 2: Set missing default values.
self._handle_missing_default_values()
# Get GPU memory capacity, which is a common dependency for several configuration steps.
gpu_mem = get_device_memory_capacity(self.device)
# Step 3: Handle memory-related configurations.
self._handle_mem_fraction_static(gpu_mem)
self._handle_chunked_prefill_size(gpu_mem)
# Step 4: Handle CUDA graph settings.
self._handle_cuda_graph_max_bs(gpu_mem)
# Step 5: Handle device-specific backends.
self._handle_hpu_backends()
self._handle_cpu_backends()
# Step 6: Apply model-specific adjustments.
if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
self.model_specific_adjustments()
# Step 7: Set kernel backends.
self._handle_sampling_backend()
self._handle_attention_backend_compatibility()
self._handle_page_size()
self._handle_amd_specifics()
self._handle_grammar_backend()
# Step 8: Handle data parallelism.
self._handle_data_parallelism()
# Step 9: Handle MoE configurations.
self._handle_moe_kernel_config()
self._handle_deepep_moe()
self._handle_eplb_and_dispatch()
self._handle_expert_distribution_metrics()
# Step 10: Handle pipeline parallelism.
self._handle_pipeline_parallelism()
# Step 11: Handle Hicache settings.
self._handle_hicache()
# Step 12: Handle speculative decoding logic.
self._handle_speculative_decoding()
# Step 13: Handle model loading format.
self._handle_load_format()
# Step 14: Handle PD disaggregation.
self._handle_disaggregation()
# Step 15: Validate tokenizer settings.
self._handle_tokenizer_batching()
# Step 16: Propagate environment variables.
self._handle_environment_variables()
# Step 17: Validate cache settings.
self._handle_cache_compatibility()
# Step 18: Validate metrics labels.
self._handle_metrics_labels()
# Step 19: Handle deterministic inference.
self._handle_deterministic_inference()
# Step 20: Handle any other necessary validations.
self._handle_other_validations()
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
@@ -1101,24 +1128,6 @@ class ServerArgs:
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
required=True,
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-ip",
type=str,
default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
help="The ip of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-service-port",
type=int,
default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
help="The service port of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-send-weights-group-ports",
type=json_list_type,
default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
help="The communication group ports for loading weights from remote instance.",
)
parser.add_argument(
"--tokenizer-path",
type=str,
@@ -2573,6 +2582,24 @@ class ServerArgs:
action="store_true",
help="Disable mmap while loading weight using safetensors.",
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-ip",
type=str,
default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
help="The ip of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-service-port",
type=int,
default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
help="The service port of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-send-weights-group-ports",
type=json_list_type,
default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
help="The communication group ports for loading weights from remote instance.",
)
# For PD-Multiplexing
parser.add_argument(
@@ -2598,38 +2625,38 @@ class ServerArgs:
# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
action="store_true",
help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
action=DeprecatedAction,
help="NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead.",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
action=DeprecatedAction,
help="NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
action=DeprecatedAction,
help="NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead.",
)
parser.add_argument(
"--enable-flashinfer-cutedsl-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
action=DeprecatedAction,
help="NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead.",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
action=DeprecatedAction,
help="NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead.",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="(Deprecated) Use triton moe grouped gemm kernel.",
action=DeprecatedAction,
help="NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead.",
)
parser.add_argument(
"--enable-flashinfer-mxfp4-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
action=DeprecatedAction,
help="NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead.",
)
@classmethod
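The deprecated flags above switch from action="store_true" to action=DeprecatedAction, whose implementation is not part of this diff. A minimal sketch of what such an argparse action can look like, assuming it simply rejects the flag with the migration hint; the real class in sglang may instead warn and continue:

import argparse


class DeprecatedAction(argparse.Action):
    """Reject a removed flag and surface the migration hint from its help text."""

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # The help string carries the "please set X instead" message shown above.
        parser.error(self.help)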
@@ -2862,81 +2889,6 @@ class ServerArgs:
val >= 0 for val in bucket_values
), f"{arg_name} customer rule bucket values should be non-negative"
def model_specific_adjustments(self):
hf_config = self.get_hf_config()
model_arch = hf_config.architectures[0]
if model_arch in ["GptOssForCausalLM"]:
if self.attention_backend is None:
if is_cuda() and is_sm100_supported():
self.attention_backend = "trtllm_mha"
elif is_cuda() and is_sm90_supported():
self.attention_backend = "fa3"
else:
self.attention_backend = "triton"
supported_backends = ["triton", "trtllm_mha", "fa3"]
logger.info(
f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
)
assert (
self.attention_backend in supported_backends
), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
if is_sm100_supported():
if not self.enable_dp_attention:
self.enable_flashinfer_allreduce_fusion = True
logger.info(
"Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
)
quantization_config = getattr(hf_config, "quantization_config", None)
is_mxfp4_quant_format = (
quantization_config is not None
and quantization_config.get("quant_method") == "mxfp4"
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.moe_runner_backend = "flashinfer_mxfp4"
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.moe_runner_backend == "triton_kernel":
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if (
self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True
if is_mxfp4_quant_format:
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
elif "Llama4" in model_arch and self.device != "cpu":
assert self.attention_backend in {
"fa3",
"aiter",
"triton",
}, "fa3, aiter, or triton is required for Llama4 model"
elif model_arch in [
"Gemma2ForCausalLM",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForCausalLM",
"Gemma3nForConditionalGeneration",
]:
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
# It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
logger.warning(
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
)
self.disable_hybrid_swa_memory = True
def adjust_mem_fraction_for_vlm(self, model_config):
vision_config = getattr(model_config.hf_config, "vision_config", None)
if vision_config is None: