server_args.py 12.2 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
3
4
import argparse
import dataclasses
5
import random
6
from typing import List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
7
8
9
10


@dataclasses.dataclass
class ServerArgs:
    """All configuration arguments for launching the server.

    Defaults defined on the fields below are also used as the argparse
    defaults in :meth:`add_cli_args`, so the dataclass is the single source
    of truth for configuration values.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    load_format: str = "auto"
    dtype: str = "auto"
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    chat_template: Optional[str] = None

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_prefill_tokens: Optional[int] = None
    max_running_requests: Optional[int] = None
    schedule_heuristic: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 8
    random_seed: Optional[int] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: str = ""

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_disk_cache: bool = False
    attention_reduce_in_fp32: bool = False
    enable_p2p_check: bool = False

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    def __post_init__(self):
        """Fill in defaults that depend on other fields."""
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
        if self.mem_fraction_static is None:
            # Larger TP sizes leave a smaller static fraction because the
            # per-GPU non-static overhead grows with the number of ranks.
            if self.tp_size >= 8:
                self.mem_fraction_static = 0.80
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.82
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.85
            else:
                self.mem_fraction_static = 0.90
        # Normalize additional_ports to a (possibly empty) list.
        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []
        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every ServerArgs field as a command-line flag on ``parser``.

        Defaults are read from the dataclass fields so the two stay in sync.
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        # NOTE(review): store_true makes the CLI default False, but the
        # dataclass field defaults to True — direct construction and CLI
        # construction disagree. Preserved as-is; confirm which is intended.
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            help="The quantization method.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            # Fixed typo: "buliltin" -> "builtin".
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
            default=ServerArgs.schedule_heuristic,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling heuristic.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server",
        )

        # Data parallelism
        parser.add_argument(
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            # Use the dataclass default (same value) for consistency with
            # every other flag instead of a hard-coded literal.
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument("--node-rank", type=int, help="The node rank.")

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer inference kernels",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
            # Fixed typo: "intermidiate" -> "intermediate".
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
            "This only affects Triton attention kernels",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a parsed argparse namespace.

        Works because every dataclass field has a matching CLI flag
        registered in :meth:`add_cli_args`.
        """
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the base HTTP URL of the server."""
        return f"http://{self.host}:{self.port}"

    def print_mode_args(self):
        """Return a short string summarizing the optimization/debug flags."""
        return (
            f"disable_flashinfer={self.disable_flashinfer}, "
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
337

338
339
340
@dataclasses.dataclass
class ModelPortArgs:
    """Network endpoints for one tensor-parallel model worker group."""

    # Port used for NCCL/distributed initialization of the worker group
    # (inferred from the field name — confirm against the launcher code).
    nccl_port: int
    # IP addresses of the tensor-parallel workers; presumably one per TP
    # rank — verify against the caller that constructs this.
    model_tp_ips: List[str]
    # Ports of the tensor-parallel workers, parallel to model_tp_ips.
    model_tp_ports: List[int]


Lianmin Zheng's avatar
Lianmin Zheng committed
345
346
347
348
349
@dataclasses.dataclass
class PortArgs:
    """Inter-process port assignments for the server's components."""

    # Port of the tokenizer process (name-based inference — confirm usage).
    tokenizer_port: int
    # Port of the router process.
    router_port: int
    # Port of the detokenizer process.
    detokenizer_port: int
    # Port args for each model worker group; presumably one entry per
    # data-parallel replica — verify against the launcher.
    model_port_args: List[ModelPortArgs]