""" Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """The arguments of the server.""" import argparse import dataclasses import logging import random from typing import List, Optional, Union from sglang.srt.utils import is_hip, is_ipv6 logger = logging.getLogger(__name__) @dataclasses.dataclass class ServerArgs: # Model and tokenizer model_path: str tokenizer_path: Optional[str] = None tokenizer_mode: str = "auto" skip_tokenizer_init: bool = False load_format: str = "auto" dtype: str = "auto" kv_cache_dtype: str = "auto" trust_remote_code: bool = True context_length: Optional[int] = None quantization: Optional[str] = None served_model_name: Optional[str] = None chat_template: Optional[str] = None is_embedding: bool = False # Port host: str = "127.0.0.1" port: int = 30000 additional_ports: Optional[Union[List[int], int]] = None # Memory and scheduling mem_fraction_static: Optional[float] = None max_running_requests: Optional[int] = None max_total_tokens: Optional[int] = None chunked_prefill_size: int = 8192 max_prefill_tokens: int = 16384 schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 # Other runtime options tp_size: int = 1 stream_interval: int = 1 random_seed: Optional[int] = None constrained_json_whitespace_pattern: Optional[str] = None # Logging log_level: str = "info" log_level_http: Optional[str] = None log_requests: bool = False show_time_cost: bool = False # Other api_key: Optional[str] = None file_storage_pth: str = "SGLang_storage" # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" # Distributed args dist_init_addr: Optional[str] = None nnodes: int = 1 node_rank: int = 0 # Model override args in JSON json_model_override_args: str = "{}" # Optimization/debug options attention_backend: Optional[str] = None sampling_backend: Optional[str] = None disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False disable_radix_cache: bool = False disable_regex_jump_forward: bool = False disable_cuda_graph: bool = False disable_cuda_graph_padding: bool = False disable_disk_cache: bool = False disable_custom_all_reduce: bool = False disable_mla: bool = False enable_mixed_chunk: bool = False enable_torch_compile: bool = False max_torch_compile_bs: int = 32 torchao_config: str = "" enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False # LoRA lora_paths: Optional[List[str]] = None max_loras_per_batch: int = 8 def __post_init__(self): # Set missing default values if self.tokenizer_path is None: self.tokenizer_path = self.model_path if self.served_model_name is None: self.served_model_name = self.model_path if self.chunked_prefill_size <= 0: # Disable chunked prefill self.chunked_prefill_size = None # Mem fraction depends on the tensor parallelism size if self.mem_fraction_static is None: if self.tp_size >= 16: self.mem_fraction_static = 0.79 elif self.tp_size >= 8: self.mem_fraction_static = 0.83 elif self.tp_size >= 4: self.mem_fraction_static = 0.85 elif self.tp_size >= 2: self.mem_fraction_static = 0.87 else: 
    def __post_init__(self):
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Deprecation warnings
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
            self.attention_backend = "triton"
        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead."
            )
            self.sampling_backend = "pytorch"

        # ROCm: flashinfer available later
        if is_hip():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"
        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "For unknown reasons, the tokenizer adds an additional token at the end of "
                "the prompt when trust_remote_code=True, so it is disabled for this model."
            )
            self.trust_remote_code = False

        if "gemma-2" in self.model_path.lower():
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.attention_backend = "flashinfer"
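    # Illustrative sketch (hypothetical flag values): ServerArgs is normally
    # built from command-line flags via add_cli_args/from_cli_args below, e.g.
    #
    #   parser = argparse.ArgumentParser()
    #   ServerArgs.add_cli_args(parser)
    #   raw = parser.parse_args(["--model-path", "my-org/my-model", "--tp-size", "2"])
    #   server_args = ServerArgs.from_cli_args(raw)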
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip initializing the tokenizer and pass input_ids in generate requests.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use the model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in the OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 disables chunked prefill.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
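        # Illustrative invocation (assumed launcher entry point, placeholder model
        # path): serve with 2-way tensor parallelism and a smaller static memory
        # fraction.
        #   python -m sglang.launch_server --model-path my-org/my-model \
        #       --tp-size 2 --mem-fraction-static 0.8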
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model to generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of the HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set the API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in the backend.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing the distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
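        # Illustrative value (assumed HF config keys) for --json-model-override-args:
        #   --json-model-override-args '{"rope_scaling": {"type": "dynamic", "factor": 2.0}}'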
        # Optimization/debug options
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--max-torch-compile-bs",
            type=int,
            default=ServerArgs.max_torch_compile_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access; otherwise, P2P access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory-efficient weight loading with quantization (quantize per layer during loading).",
        )

        # LoRA options
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either plain paths or renamed paths in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, including base-only requests.",
        )
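    # Illustrative --lora-paths usage (hypothetical adapter paths): plain paths
    # and {name}={path} renames can be mixed; see LoRAPathAction at the bottom
    # of this file for how they are parsed into a dict.
    #   --lora-paths adapters/sql-lora math=/data/adapters/math_lora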
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"

        assert self.dp_size == 1, (
            "The support for data parallelism is temporarily disabled during refactor. "
            "Please use sglang<=0.3.2 or wait for later updates."
        )


def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


@dataclasses.dataclass
class PortArgs:
    # The port for the tokenizer to receive inputs from the detokenizer (zmq)
    tokenizer_port: int

    # The port for the scheduler to receive inputs from the tokenizer (zmq)
    scheduler_port: int

    # The port for the detokenizer to receive inputs from the scheduler (zmq)
    detokenizer_port: int

    # The ports for nccl initialization for multiple TP groups (torch.dist)
    nccl_ports: List[int]


class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path
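
if __name__ == "__main__":
    # Illustrative, hedged usage sketch (not part of the original module): build
    # ServerArgs from CLI-style arguments and print the derived server URL.
    # The model path below is a placeholder.
    example_args = prepare_server_args(["--model-path", "my-org/my-model"])
    print(example_args.url())  # e.g. http://127.0.0.1:30000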