arg_utils.py 25.6 KB
Newer Older
1
import argparse
2
3
import dataclasses
from dataclasses import dataclass
4
from typing import Optional
5

6
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
7
8
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
9
                         TokenizerPoolConfig, VisionLanguageConfig)
10
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
11
from vllm.utils import str_to_int_tuple
12
13


14
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every option is a dataclass field. ``add_cli_args`` registers a matching
    command-line flag for the options, ``from_cli_args`` builds an instance
    by reading one namespace attribute per dataclass field, and
    ``create_engine_config`` turns the values into the config objects that
    make up an ``EngineConfig``.
    """
    model: str
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    # The annotation is required: without it this is a plain class attribute
    # rather than a dataclass field, and from_cli_args would never copy the
    # value of --lora-dtype onto the instance.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        """Fall back to the model name/path when no tokenizer is given."""
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Args:
            parser: The parser to register the engine flags on.

        Returns:
            The same parser, with the engine flags added.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=str,
            default=None,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=None,
            help='The specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave which assumes tensorizer_uri is set to the location of '
            'the serialized weights.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--quantization-param-path',
            type=str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc).')

        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='Use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')

        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32, 128],
                            help='Token block size for contiguous chunks of '
                            'tokens.')
        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')

        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')

        # Tokenizer pool arguments
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')

        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=str,
            default=None,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        # Speculative decoding arguments
        parser.add_argument(
            '--speculative-model',
            type=str,
            default=None,
            help=
            'The name of the draft model to be used in speculative decoding.')

        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=None,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a parsed argument namespace.

        One attribute is read from ``args`` for every dataclass field of
        ``cls``, so the namespace must provide an attribute for each field.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self, ) -> EngineConfig:
        """Assemble the per-subsystem configs into an ``EngineConfig``.

        Raises:
            ValueError: If ``image_input_type`` is set but any of
                ``image_token_id``, ``image_input_shape`` or
                ``image_feature_size`` is missing.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.dtype, self.seed, self.revision,
            self.code_revision, self.tokenizer_revision, self.max_model_len,
            self.quantization, self.quantization_param_path,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_logprobs)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   self.num_gpu_blocks_override,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
        )

        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
            # A speculative config, when present, overrides the value
            # given through --num-lookahead-slots.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )
        # A missing or non-positive max_cpu_loras is passed on as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        if self.image_input_type:
            # Compare image_token_id against None rather than relying on
            # truthiness: 0 is a legitimate token id and must not be
            # reported as "not specified".
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)
531
532


533
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Adds three async-specific options on top of the fields inherited from
    ``EngineArgs``.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async engine flags on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When False (the default) the shared
                ``EngineArgs`` flags are registered first; when True only
                the async-specific flags are added.

        Returns:
            The parser with the flags added.
        """
        parser = (parser
                  if async_args_only else EngineArgs.add_cli_args(parser))

        # Boolean switches, registered in declaration order.
        switches = (
            ('--engine-use-ray',
             'Use Ray to start the LLM engine in a separate process as the '
             'server process.'),
            ('--disable-log-requests', 'Disable logging requests.'),
        )
        for flag, description in switches:
            parser.add_argument(flag, action='store_true', help=description)

        max_log_len_help = ('Max number of prompt characters or prompt ID '
                            'numbers being printed in log.\n\n'
                            'Default: Unlimited')
        parser.add_argument('--max-log-len',
                            type=int,
                            default=None,
                            help=max_log_len_help)
        return parser
559
560
561
562
563
564
565
566
567
568


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a new parser carrying the ``EngineArgs`` CLI flags."""
    fresh_parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)


def _async_engine_args_parser():
    """Return a new parser carrying only the async-specific CLI flags."""
    fresh_parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)