"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""The arguments of the server."""

import argparse
import dataclasses
import json
import logging
import random
from typing import List, Optional, Union

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ServerArgs:
    """The arguments of the server.

    Instances can be built directly, or from the command line via
    ``add_cli_args`` (to register the flags) and ``from_cli_args`` (to turn the
    parsed namespace into a ``ServerArgs``). ``__post_init__`` fills in derived
    defaults such as the tokenizer path and the static memory fraction.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None  # defaults to model_path
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    # NOTE(review): the dataclass default is True, but the CLI flag
    # --trust-remote-code is store_true, so instances built through
    # from_cli_args default to False. Left unchanged; confirm which is intended.
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    served_model_name: Optional[str] = None  # defaults to model_path
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    # A single int is normalized to a one-element list, None to [].
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None  # derived from tp_size if unset
    max_running_requests: Optional[int] = None
    max_num_reqs: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # Values <= 0 disable chunked prefill and are normalized to None.
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None  # drawn at random if unset

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    enable_p2p_check: bool = False
    enable_mla: bool = False
    triton_attention_reduce_in_fp32: bool = False

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    # Model override args in JSON
    json_model_override_args: Optional[dict] = None

    # Mirrors the --efficient-weight-load CLI flag. Without this field the flag
    # was parsed but silently discarded by from_cli_args. Appended last so the
    # positional order of the existing fields is unchanged.
    efficient_weight_load: bool = False

    def __post_init__(self):
        """Fill in derived defaults and normalize loosely-typed fields."""
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        if self.mem_fraction_static is None:
            # Larger tensor-parallel sizes get a smaller static fraction.
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every server flag on ``parser``.

        Defaults are taken from the dataclass where one applies. The
        parallelism flags use the long names ``--tensor-parallel-size`` and
        ``--data-parallel-size`` as their ``dest``. ``from_cli_args`` maps
        them onto ``tp_size``/``dp_size``.
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-num-reqs",
            type=int,
            default=ServerArgs.max_num_reqs,
            help="The maximum number of requests to serve in the memory pool. If the model has a large context length, you may need to decrease this value to avoid out-of-memory errors.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            default=ServerArgs.nccl_init_addr,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank",
            type=int,
            default=ServerArgs.node_rank,
            help="The node rank.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile, experimental feature.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--enable-mla",
            action="store_true",
            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            # The trailing space keeps the two concatenated sentences apart.
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            default=ServerArgs.efficient_weight_load,
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            default=ServerArgs.json_model_override_args,
            help="A dictionary in JSON string format used to override default model configurations.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
        """Build a ``ServerArgs`` from a namespace produced via ``add_cli_args``.

        The namespace is updated in place, as before:

        * ``tp_size``/``dp_size`` are copied from the ``--tensor-parallel-size``
          and ``--data-parallel-size`` options.
        * ``json_model_override_args`` is decoded from its JSON text, or set to
          None if empty.

        A value that was already decoded, for example by an earlier call on
        the same namespace, is left as-is. Repeated calls therefore no longer
        fail with a TypeError.

        Raises:
            AttributeError: if a required field (``model_path``) or either
                parallelism option is missing from ``args``.
            json.JSONDecodeError: if the JSON override text is malformed.
        """
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size

        override = args.json_model_override_args
        if isinstance(override, str):
            override = json.loads(override) if override else None
        args.json_model_override_args = override

        kwargs = {
            field.name: getattr(args, field.name)
            for field in dataclasses.fields(cls)
            # Defaulted fields absent from the namespace fall back to the
            # dataclass default. A missing required field still raises.
            if hasattr(args, field.name)
            or (
                field.default is dataclasses.MISSING
                and field.default_factory is dataclasses.MISSING
            )
        }
        return cls(**kwargs)

    def url(self):
        """Return the ``http://host:port`` base URL of the server."""
        return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        """Validate cross-field constraints and apply model-specific overrides.

        Raises:
            AssertionError: if ``tp_size`` is not divisible by ``nnodes``, or
                if data parallelism (``dp_size > 1``) is combined with a
                ``node_rank``.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_code=True"
            )
            self.trust_remote_code = False
        if "gemma-2" in self.model_path.lower():
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.disable_flashinfer = False


def prepare_server_args(args: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    A fresh parser is built, every ``ServerArgs`` flag is registered on it,
    ``args`` is parsed, and the result is converted with
    ``ServerArgs.from_cli_args``.

    Args:
        args: The raw command-line tokens, i.e. a list of strings rather than
            an already-parsed namespace. Typically this is `sys.argv[1:]`, so
            that passing no arguments behaves the same as with `parse_args`.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(args)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


@dataclasses.dataclass
class PortArgs:
    # Port assignments for the server's components. The field names suggest
    # tokenizer, controller, detokenizer and NCCL endpoints; how each port is
    # used is defined by callers outside this view.
    tokenizer_port: int
    controller_port: int
    detokenizer_port: int
    nccl_ports: List[int]