arg_utils.py 27.2 KB
Newer Older
1
import argparse
2
3
import dataclasses
from dataclasses import dataclass
4
from typing import Optional
5

6
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
7
8
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
9
                         TokenizerPoolConfig, VisionLanguageConfig)
10
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
11
from vllm.utils import str_to_int_tuple
12
13


14
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    The fields hold the engine options. ``add_cli_args`` registers the
    matching command-line flags, ``from_cli_args`` builds an instance from a
    parsed namespace (one attribute per dataclass field), and
    ``create_engine_config`` turns an instance into the individual config
    objects bundled in an ``EngineConfig``.
    """
    model: str
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # Must be annotated: an un-annotated class attribute is not a dataclass
    # field, so from_cli_args() would silently drop the --lora-dtype value.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None

    def __post_init__(self):
        """Default the tokenizer to the model when none was given."""
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Args:
            parser: The parser to register the engine flags on.

        Returns:
            The same ``parser`` object, with the flags added.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=str,
            default=None,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=None,
            help='The specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave which assumes tensorizer_uri is set to the location of '
            'the serialized weights.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--quantization-param-path',
            type=str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='Use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')
        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=str,
            default=None,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        parser.add_argument(
            '--speculative-model',
            type=str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')

        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')

        # Parsed as int to match the Optional[int] field; it is a length.
        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

        parser.add_argument('--model-loader-extra-config',
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a parsed argument namespace.

        Every dataclass field of ``cls`` is read from ``args`` by name, so
        ``args`` must carry an attribute for each field.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> 'EngineConfig':
        """Assemble the per-subsystem configs into an ``EngineConfig``.

        Returns:
            An ``EngineConfig`` holding the model, cache, parallel,
            scheduler, device, LoRA, vision-language, speculative, load and
            decoding configs built from this instance's fields.

        Raises:
            ValueError: If ``image_input_type`` is set but ``image_token_id``,
                ``image_input_shape`` or ``image_feature_size`` is missing.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.dtype, self.seed, self.revision,
            self.code_revision, self.tokenizer_revision, self.max_model_len,
            self.quantization, self.quantization_param_path,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_logprobs, self.skip_tokenizer_init)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   self.num_gpu_blocks_override,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
        )

        # When a speculative config exists, its lookahead slot count takes
        # precedence over the --num-lookahead-slots value.
        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )
        # A missing or non-positive max_cpu_loras is forwarded as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        if self.image_input_type:
            # Compare the token id against None: 0 is falsy but is still an
            # explicitly supplied id, so only a missing id is rejected.
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)
562
563


564
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with the options that only the asynchronous
    engine exposes.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async-engine flags on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When true, skip the shared ``EngineArgs`` flags
                and add only the async-specific ones.

        Returns:
            The parser with the flags added.
        """
        parser = (parser
                  if async_args_only else EngineArgs.add_cli_args(parser))

        # The two boolean switches share the same registration shape.
        switches = (
            ('--engine-use-ray',
             'Use Ray to start the LLM engine in a separate process as the '
             'server process.'),
            ('--disable-log-requests', 'Disable logging requests.'),
        )
        for flag, description in switches:
            parser.add_argument(flag, action='store_true', help=description)

        max_log_len_help = ('Max number of prompt characters or prompt ID '
                            'numbers being printed in log.'
                            '\n\nDefault: Unlimited')
        parser.add_argument('--max-log-len',
                            type=int,
                            default=None,
                            help=max_log_len_help)
        return parser
590
591
592
593
594
595
596
597
598
599


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with the ``EngineArgs`` flags."""
    parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(parser)


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-specific flags."""
    parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)