arg_utils.py 20.1 KB
Newer Older
1
import argparse
2
3
4
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
5

6
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
7
8
9
                         ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
                         VisionLanguageConfig)
from vllm.utils import str_to_int_tuple
10
11


12
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every field has a matching command-line flag registered in
    `add_cli_args`. `from_cli_args` builds an instance from a parsed
    `argparse.Namespace` by reading one attribute per dataclass field, and
    `create_engine_configs` converts the instance into the config objects
    handed to the engine.
    """
    model: str
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    # NOTE(review): from the CLI this arrives as a JSON string (see
    # --tokenizer-pool-extra-config), not a dict. It is forwarded unchanged
    # to TokenizerPoolConfig.create_config, which presumably parses it --
    # confirm.
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    # The annotation is required. Without it @dataclass treats this as a
    # plain class attribute, so it is not an __init__ parameter, is not
    # returned by dataclasses.fields(), and from_cli_args would silently
    # ignore --lora-dtype.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0

    def __post_init__(self):
        # Use the model name/path as the tokenizer when none is given.
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per `EngineArgs` field on `parser`. Each flag's
        dest matches the field name, which `from_cli_args` relies on.
        Returns the same parser so calls can be chained.
        """

        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=EngineArgs.revision,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=str,
            default=EngineArgs.code_revision,
            help='the specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=EngineArgs.tokenizer_revision,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8_e5m2'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. Note FP8 is not supported when cuda version is '
            'lower than 11.8.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile ray workers')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32, 128],
                            help='token block size')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='the fraction of GPU memory to be used for '
            'the model executor, which can range from 0 to 1. '
            'If unspecified, will use the default value of 0.9.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('max number of log probs to return when logprobs is '
                  'specified in SamplingParams'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', 'gptq', 'squeezellm', None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron"],
                            help='Device type for vLLM execution.')
        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=str,
            default=EngineArgs.image_input_type,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=EngineArgs.image_token_id,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=str,
            default=EngineArgs.image_input_shape,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=EngineArgs.image_feature_size,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a namespace produced by `add_cli_args`.

        Reads one attribute from `args` per dataclass field of `cls`. A
        field without a matching attribute raises AttributeError.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
               DeviceConfig, Optional[LoRAConfig],
               Optional[VisionLanguageConfig]]:
        """Convert these arguments into the engine's config objects.

        Returns:
            (model, cache, parallel, scheduler, device, lora,
            vision_language) configs. The LoRA config is None unless
            `enable_lora` is set. The vision-language config is None unless
            `image_input_type` is set.

        Raises:
            ValueError: if `image_input_type` is set without
                `image_token_id`, `image_input_shape` and
                `image_feature_size`.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.download_dir, self.load_format,
            self.dtype, self.seed, self.revision, self.code_revision,
            self.tokenizer_revision, self.max_model_len, self.quantization,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_logprobs)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)
        # The scheduler's length limit comes from the resolved model config,
        # not from the raw --max-model-len value.
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.max_model_len,
                                           self.scheduler_delay_factor)

        lora_config = None
        if self.enable_lora:
            # A missing or non-positive max_cpu_loras is passed on as None
            # (per the --max-cpu-loras help, it then defaults to
            # max_num_seqs).
            max_cpu_loras = (self.max_cpu_loras if self.max_cpu_loras
                             and self.max_cpu_loras > 0 else None)
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=max_cpu_loras)

        vision_language_config = None
        if self.image_input_type:
            # image_token_id is compared against None because 0 is a valid
            # token id. The shape string and feature size must be non-empty
            # and non-zero.
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )

        return (model_config, cache_config, parallel_config, scheduler_config,
                device_config, lora_config, vision_language_config)
413
414


415
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends `EngineArgs` with three async-only fields and registers their
    flags after all of the base engine flags.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base engine flags, then the async-engine flags.

        Returns the same parser so calls can be chained.
        """
        # The base engine flags are registered first.
        parser = EngineArgs.add_cli_args(parser)

        # On/off switches that only the async engine understands.
        switches = (
            ('--engine-use-ray',
             'use Ray to start the LLM engine in a separate process as the '
             'server process.'),
            ('--disable-log-requests', 'disable logging requests'),
        )
        for flag, description in switches:
            parser.add_argument(flag, action='store_true', help=description)

        max_log_len_help = ('max number of prompt characters or prompt '
                            'ID numbers being printed in log. '
                            'Default: unlimited.')
        parser.add_argument('--max-log-len',
                            type=int,
                            default=None,
                            help=max_log_len_help)
        return parser