server_args.py 12.9 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
3
4
import argparse
import dataclasses
5
import random
6
from typing import List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
7
8
9
10


@dataclasses.dataclass
class ServerArgs:
    """The arguments of the server.

    Field defaults here are also used as the CLI defaults registered by
    ``add_cli_args``. ``None`` defaults are resolved in ``__post_init__``
    or left for downstream code to interpret.
    """

    # Model and tokenizer
    model_path: str  # local folder or Hugging Face repo ID of the model weights
    tokenizer_path: Optional[str] = None  # falls back to model_path in __post_init__
    tokenizer_mode: str = "auto"  # "auto" (fast tokenizer if available) or "slow"
    load_format: str = "auto"  # weight format: auto/pt/safetensors/npcache/dummy
    dtype: str = "auto"  # dtype for weights/activations (see --dtype help)
    # NOTE(review): defaults to True here, but the CLI flag is store_true
    # (default False) — confirm which default is intended.
    trust_remote_code: bool = True
    context_length: Optional[int] = None  # None -> use the model's config.json value
    quantization: Optional[str] = None  # quantization method name, e.g. "awq"
    chat_template: Optional[str] = None  # builtin template name or template file path

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    additional_ports: Optional[Union[List[int], int]] = None  # normalized to a list in __post_init__

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None  # None -> derived from tp_size in __post_init__
    max_prefill_tokens: Optional[int] = None  # max tokens in a prefill batch
    max_running_requests: Optional[int] = None  # cap on concurrently running requests
    schedule_heuristic: str = "lpm"  # one of: lpm, random, fcfs, dfs-weight
    schedule_conservativeness: float = 0.8  # larger -> more conservative scheduling

    # Other runtime options
    tp_size: int = 1  # tensor parallelism size
    stream_interval: int = 1  # streaming buffer size in tokens
    random_seed: Optional[int] = None  # None -> randomized in __post_init__

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None  # None -> reuse log_level
    log_requests: bool = False  # log inputs/outputs of all requests
    show_time_cost: bool = False  # show time cost of custom marks

    # Other
    api_key: str = ""  # empty string means no API-key check

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"  # or "shortest_queue"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_disk_cache: bool = False
    enable_torch_compile: bool = False
    attention_reduce_in_fp32: bool = False  # only affects Triton attention kernels
    enable_p2p_check: bool = False
    efficient_weight_load: bool = False

    # Distributed args
    nccl_init_addr: Optional[str] = None  # nccl init address for multi-node serving
    nnodes: int = 1
    node_rank: Optional[int] = None

Lianmin Zheng's avatar
Lianmin Zheng committed
68
69
70
    def __post_init__(self):
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
Lianmin Zheng's avatar
Lianmin Zheng committed
71
        if self.mem_fraction_static is None:
72
73
74
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.74
            elif self.tp_size >= 8:
75
                self.mem_fraction_static = 0.78
Lianmin Zheng's avatar
Lianmin Zheng committed
76
            elif self.tp_size >= 4:
77
                self.mem_fraction_static = 0.82
Lianmin Zheng's avatar
Lianmin Zheng committed
78
79
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.85
Lianmin Zheng's avatar
Lianmin Zheng committed
80
            else:
81
                self.mem_fraction_static = 0.88
82
83
84
85
        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []
Lianmin Zheng's avatar
Lianmin Zheng committed
86

87
88
89
        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

Lianmin Zheng's avatar
Lianmin Zheng committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
104
105
106
107
108
109
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
110
111
112
113
114
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
115
            help="The additional ports specified for the server.",
116
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
117
118
119
120
121
122
123
124
125
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
143
            "--dtype",
Cody Yu's avatar
Cody Yu committed
144
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
145
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
146
147
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
148
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
149
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
150
151
152
153
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
154
155
            '* "float32" for FP32 precision.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
156
157
158
159
160
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
161
162
163
164
165
166
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
167
168
169
170
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
171
172
173
174
175
176
177
178
179
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
180
181
            help="The quantization method.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
182
183
184
185
186
187
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
188
189
190
191
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
192
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
193
        )
194
        parser.add_argument(
195
            "--max-prefill-tokens",
196
            type=int,
197
            default=ServerArgs.max_prefill_tokens,
198
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
199
        )
200
201
202
203
204
205
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
206
207
208
209
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
            default=ServerArgs.schedule_heuristic,
Liangsheng Yin's avatar
Liangsheng Yin committed
210
            choices=["lpm", "random", "fcfs", "dfs-weight"],
211
            help="The scheduling heuristic.",
Lianmin Zheng's avatar
Lianmin Zheng committed
212
        )
213
214
215
216
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
217
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
218
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
219
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
220
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
221
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
222
            default=ServerArgs.tp_size,
223
            help="The tensor parallelism size.",
224
        )
225
226
227
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
228
            default=ServerArgs.stream_interval,
229
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
230
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
231
232
233
234
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
235
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
236
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
237
238
239
240
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
241
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
242
        )
243
        parser.add_argument(
244
245
246
247
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
248
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
249
        parser.add_argument(
250
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
251
            action="store_true",
252
            help="Log the inputs and outputs of all requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
253
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
254
255
256
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
257
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
258
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
259
260
261
262
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
Ying Sheng's avatar
Ying Sheng committed
263
            help="Set API key of the server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
264
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
265

266
267
268
269
270
        # Data parallelism
        parser.add_argument(
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
271
            help="The data parallelism size.",
272
273
274
275
276
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
277
            help="The load balancing strategy for data parallelism.",
278
279
280
281
282
283
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

284
285
286
287
        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
Ying Sheng's avatar
Ying Sheng committed
288
            help="The nccl init address of multi-node server.",
289
290
        )
        parser.add_argument(
Ying Sheng's avatar
Ying Sheng committed
291
            "--nnodes", type=int, default=1, help="The number of nodes."
292
        )
Ying Sheng's avatar
Ying Sheng committed
293
        parser.add_argument("--node-rank", type=int, help="The node rank.")
294

Lianmin Zheng's avatar
Lianmin Zheng committed
295
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
296
        parser.add_argument(
297
            "--disable-flashinfer",
Liangsheng Yin's avatar
Liangsheng Yin committed
298
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
299
            help="Disable flashinfer inference kernels.",
Liangsheng Yin's avatar
Liangsheng Yin committed
300
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
301
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
302
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
303
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
304
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
305
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
306
        parser.add_argument(
307
            "--disable-regex-jump-forward",
Liangsheng Yin's avatar
Liangsheng Yin committed
308
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
309
            help="Disable regex jump-forward.",
Liangsheng Yin's avatar
Liangsheng Yin committed
310
        )
311
312
313
314
315
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
316
317
318
319
320
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
321
322
323
324
325
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile, experimental feature.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
326
327
328
329
330
331
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
            "This only affects Triton attention kernels",
        )
332
333
334
335
336
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
337
338
339
340
341
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
342
343
344
345
346
347
348
349
350

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        return f"http://{self.host}:{self.port}"

Lianmin Zheng's avatar
Lianmin Zheng committed
351
    def print_mode_args(self):
Liangsheng Yin's avatar
Liangsheng Yin committed
352
        return (
353
            f"disable_flashinfer={self.disable_flashinfer}, "
354
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
Lianmin Zheng's avatar
Lianmin Zheng committed
355
            f"disable_radix_cache={self.disable_radix_cache}, "
Liangsheng Yin's avatar
Liangsheng Yin committed
356
357
358
359
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
360
361
362
363

@dataclasses.dataclass
class PortArgs:
    """The ports used by the server's internal components to communicate."""

    # Port for the tokenizer process.
    tokenizer_port: int
    # Port for the controller process.
    controller_port: int
    # Port for the detokenizer process.
    detokenizer_port: int
    # Ports reserved for NCCL initialization — presumably one per
    # distributed group; verify against the server launch code.
    nccl_ports: List[int]