"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""The arguments of the server."""

import argparse
import dataclasses
import logging
import random
from typing import List, Optional, Union

# Module-level logger; ServerArgs.check_server_args reports the settings it
# overrides through this logger.
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ServerArgs:
    """Configuration options for launching the server.

    Build an instance directly, or from the command line via ``add_cli_args``
    followed by ``from_cli_args``. ``__post_init__`` fills in the defaults that
    depend on other fields (tokenizer path, served model name, memory fraction,
    additional ports, random seed).
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    dtype: str = "auto"
    # NOTE(review): the dataclass default is True, but the --trust-remote-code
    # CLI flag is store_true (False unless passed), so the two entry points
    # disagree. Left unchanged to preserve existing behavior.
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_num_reqs: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # A value <= 0 disables chunked prefill; __post_init__ turns it into None.
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    enable_p2p_check: bool = False
    enable_mla: bool = False
    triton_attention_reduce_in_fp32: bool = False

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    # Backs the --efficient-weight-load flag. Without this field,
    # from_cli_args silently dropped the flag. It is appended last so the
    # positional order of the earlier fields is unchanged.
    efficient_weight_load: bool = False

    def __post_init__(self):
        """Fill in defaults that are derived from other fields."""
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        # Guard against None so that passing chunked_prefill_size=None
        # explicitly (meaning "disabled") does not raise a TypeError.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        if self.mem_fraction_static is None:
            # Larger tensor-parallel sizes get a smaller static memory fraction.
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Normalize additional_ports to a list.
        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every ServerArgs option on ``parser``.

        Most defaults are read from the corresponding dataclass field. The
        exceptions are the store_true flags (which default to False) and
        --additional-ports (which defaults to an empty list).
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-num-reqs",
            type=int,
            default=ServerArgs.max_num_reqs,
            help="The maximum number of requests to serve in the memory pool. If the model has a large context length, you may need to decrease this value to avoid out-of-memory errors.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        # dest is "tensor_parallel_size"; from_cli_args copies it to tp_size.
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )

        # Data parallelism
        # dest is "data_parallel_size"; from_cli_args copies it to dp_size.
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            default=ServerArgs.nccl_init_addr,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank",
            type=int,
            default=ServerArgs.node_rank,
            help="The node rank.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile, experimental feature.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--enable-mla",
            action="store_true",
            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            # The two literals are concatenated, so the first must end with a space.
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a namespace produced by ``add_cli_args``.

        Copies the parallelism sizes onto the field names (this mutates
        ``args``), then reads one attribute per dataclass field.
        """
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the server's base HTTP URL built from host and port."""
        return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        """Validate cross-field constraints and apply per-model overrides.

        Raises:
            AssertionError: if tp_size is not divisible by nnodes, or if data
                parallelism is combined with multi-node serving.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_code=True"
            )
            self.trust_remote_code = False
        if "gemma-2" in self.model_path.lower():
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.disable_flashinfer = False


@dataclasses.dataclass
class PortArgs:
    """Port assignments for the server's tokenizer, controller and detokenizer.

    ``nccl_ports`` holds a list of ports, presumably one per NCCL process
    group; TODO confirm against the code that allocates them.
    """

    tokenizer_port: int
    controller_port: int
    detokenizer_port: int
    nccl_ports: List[int]