arg_utils.py 20.7 KB
Newer Older
1
import argparse
2
3
4
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
5

6
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
7
8
9
                         ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
                         VisionLanguageConfig)
from vllm.utils import str_to_int_tuple
10
11


12
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Each dataclass field has a matching ``--kebab-case`` flag registered in
    ``add_cli_args``. ``from_cli_args`` copies parsed values back by field
    name, so a setting only round-trips from the command line if it is a
    real dataclass field, i.e. declared *with* a type annotation.
    """
    model: str
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    # When set from the CLI this holds the raw JSON string given to
    # --tokenizer-pool-extra-config; it is forwarded unchanged to
    # TokenizerPoolConfig.create_config.
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    # Must carry a type annotation: an unannotated class attribute is not a
    # dataclass field, so from_cli_args would silently ignore --lora-dtype.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    forced_num_gpu_blocks: Optional[int] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0

    def __post_init__(self):
        """Default the tokenizer to the model name/path when not given."""
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers the engine flags on ``parser`` and returns the same parser
        so calls can be chained.
        """

        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=str,
            default=None,
            help='the specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=None,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8_e5m2'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. Note FP8 is not supported when cuda version is '
            'lower than 11.8.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile ray workers')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32, 128],
                            help='token block size')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='the fraction of GPU memory to be used for '
            'the model executor, which can range from 0 to 1. '
            'If unspecified, will use the default value of 0.9.')
        parser.add_argument(
            '--forced-num-gpu-blocks',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('max number of log probs to return when logprobs is '
                  'specified in SamplingParams'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', 'gptq', 'squeezellm', None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron"],
                            help='Device type for vLLM execution.')
        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=str,
            default=None,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance of ``cls`` from parsed command-line arguments.

        Every dataclass field of ``cls`` is read from ``args`` by its field
        name, so ``args`` is expected to come from a parser populated via
        ``cls.add_cli_args``.

        Raises:
            AttributeError: if ``args`` has no attribute for some field.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
               DeviceConfig, Optional[LoRAConfig],
               Optional[VisionLanguageConfig]]:
        """Translate these arguments into the engine's config objects.

        Returns:
            A tuple ``(model_config, cache_config, parallel_config,
            scheduler_config, device_config, lora_config,
            vision_language_config)``. ``lora_config`` is None unless
            ``enable_lora`` is set; ``vision_language_config`` is None
            unless ``image_input_type`` is set.

        Raises:
            ValueError: if ``image_input_type`` is set without
                ``image_token_id``, ``image_input_shape`` and
                ``image_feature_size``.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.download_dir, self.load_format,
            self.dtype, self.seed, self.revision, self.code_revision,
            self.tokenizer_revision, self.max_model_len, self.quantization,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_logprobs)
        # The sliding window comes from the resolved model config, so the
        # cache config must be built after it.
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   self.forced_num_gpu_blocks,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)
        # Uses the model config's (possibly derived) max_model_len rather
        # than the raw, possibly-None, argument.
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.max_model_len,
                                           self.use_v2_block_manager,
                                           self.scheduler_delay_factor)
        # Unset or non-positive max_cpu_loras is passed on as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.image_input_type:
            # image_token_id is checked against None rather than for
            # truthiness: 0 is a legitimate token id.
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        return (model_config, cache_config, parallel_config, scheduler_config,
                device_config, lora_config, vision_language_config)
428
429


430
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with settings that only the asynchronous engine
    front-end reads.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    # None means prompts are logged without truncation (see --max-log-len).
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register all ``EngineArgs`` flags plus the async-only flags.

        Returns the same ``parser`` so calls can be chained.
        """
        parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='disable logging requests')
        # Default taken from the dataclass, as EngineArgs does for its flags.
        parser.add_argument('--max-log-len',
                            type=int,
                            default=AsyncEngineArgs.max_log_len,
                            help='max number of prompt characters or prompt '
                            'ID numbers being printed in log. '
                            'Default: unlimited.')
        return parser