# arg_utils.py: vLLM engine argument dataclasses and their CLI flags.
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Union

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple


@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    ``from_cli_args`` reads every dataclass field by name from a parsed
    ``argparse.Namespace``, so each field needs a matching flag registered
    in ``add_cli_args`` (field ``foo_bar`` <-> flag ``--foo-bar``).
    ``create_engine_config`` then groups the fields into the per-subsystem
    config objects bundled in an ``EngineConfig``.
    """
    model: str
    # Falls back to ``model`` in ``__post_init__`` when left as None.
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False

    # Tokenizer pool. ``--tokenizer-pool-extra-config`` is declared with
    # ``type=str`` (a JSON string), so the value may be a str when built
    # from the CLI; presumably TokenizerPoolConfig.create_config parses
    # it -- confirm.
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[Union[str, dict]] = None

    # LoRA.
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    # Must carry a type annotation: dataclasses ignore unannotated class
    # attributes, so without it ``lora_dtype`` is missing from
    # ``dataclasses.fields()`` and ``from_cli_args`` drops ``--lora-dtype``.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None

    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    # Like the tokenizer-pool extra config, the CLI supplies a JSON string.
    model_loader_extra_config: Optional[Union[str, dict]] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None

    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'

    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        """Default the tokenizer to the model name/path when unset."""
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per ``EngineArgs`` field on ``parser`` and
        returns the same parser.
        """

        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=str,
            default=None,
            help='the specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=None,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
            ],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling. '
            '"tensorizer" will load the weights using tensorizer from '
            'CoreWeave which assumes tensorizer_uri is set to the location '
            'of the serialized weights.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria. ')
        parser.add_argument(
            '--quantization-param-path',
            type=str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria. ')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            # Read from the dataclass so the CLI and programmatic defaults
            # cannot drift apart.
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc)')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile ray workers')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32, 128],
                            help='token block size')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='the fraction of GPU memory to be used for '
            'the model executor, which can range from 0 to 1. '
            'If unspecified, will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('max number of log probs to return when logprobs is '
                  'specified in SamplingParams'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')
        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=str,
            default=None,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens')

        parser.add_argument(
            '--speculative-model',
            type=str,
            default=None,
            help=
            'The name of the draft model to be used in speculative decoding.')

        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=None,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding')

        parser.add_argument('--model-loader-extra-config',
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a namespace produced by ``add_cli_args``.

        Raises:
            AttributeError: if ``args`` lacks an attribute for any field.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> EngineConfig:
        """Translate these arguments into an ``EngineConfig``.

        Raises:
            ValueError: if ``image_input_type`` is set without
                ``image_token_id``, ``image_input_shape`` and
                ``image_feature_size``.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.dtype, self.seed, self.revision,
            self.code_revision, self.tokenizer_revision, self.max_model_len,
            self.quantization, self.quantization_param_path,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_logprobs)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   self.num_gpu_blocks_override,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
        )

        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
            # When speculative decoding is configured, its lookahead-slot
            # count takes precedence over the --num-lookahead-slots flag.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )

        # Non-positive max_cpu_loras values are passed on as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        if self.image_input_type:
            # Compare the token id against None so that a valid id of 0 is
            # not mistaken for "missing".
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Adds three options on top of ``EngineArgs``, each backed by a CLI flag
    registered in ``add_cli_args``.
    """
    # Set by --engine-use-ray.
    engine_use_ray: bool = False
    # Set by --disable-log-requests.
    disable_log_requests: bool = False
    # Set by --max-log-len; None leaves the logged prompt length unlimited.
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base engine flags, then the async-engine flags.

        Returns the parser handed back by ``EngineArgs.add_cli_args`` after
        ``--engine-use-ray``, ``--disable-log-requests`` and
        ``--max-log-len`` have been added to it (in that order).
        """
        extended_parser = EngineArgs.add_cli_args(parser)

        # (flag, add_argument keyword options), registered in order.
        async_only_flags = (
            ('--engine-use-ray',
             dict(action='store_true',
                  help='use Ray to start the LLM engine in a '
                  'separate process as the server process.')),
            ('--disable-log-requests',
             dict(action='store_true', help='disable logging requests')),
            ('--max-log-len',
             dict(type=int,
                  default=None,
                  help='max number of prompt characters or prompt '
                  'ID numbers being printed in log. '
                  'Default: unlimited.')),
        )
        for flag_name, flag_options in async_only_flags:
            extended_parser.add_argument(flag_name, **flag_options)

        return extended_parser