arg_utils.py 16.5 KB
Newer Older
1
import argparse
2
3
4
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
5

6
7
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)
8
9


10
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every field corresponds to a CLI flag registered in ``add_cli_args``
    (``field_name`` <-> ``--field-name``). ``from_cli_args`` copies values
    out of a parsed namespace using ``dataclasses.fields``, so a setting is
    only forwarded from the command line if it is declared here as an
    *annotated* dataclass field.
    """
    # Model / tokenizer selection and loading.
    model: str
    tokenizer: Optional[str] = None  # falls back to `model` (__post_init__)
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    # Parallelism.
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    # KV cache / memory.
    block_size: int = 16
    enable_prefix_caching: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    # Scheduling.
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_paddings: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    # Model / code / tokenizer revisions on the Hugging Face Hub.
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    # LoRA.
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    # Must carry a type annotation: an un-annotated class attribute is not a
    # dataclass field, so from_cli_args() would silently drop --lora-dtype.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False

    def __post_init__(self):
        # Use the model's own name/path for the tokenizer unless one is given.
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Args:
            parser: Parser to register the engine flags on (mutated in place).

        Returns:
            The same ``parser``, so calls can be chained.
        """

        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=EngineArgs.revision,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=str,
            default=EngineArgs.code_revision,
            help='the specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=EngineArgs.tokenizer_revision,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8_e5m2'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. Note FP8 is not supported when cuda version is '
            'lower than 11.8.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile ray workers')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32, 128],
                            help='token block size')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='the fraction of GPU memory to be used for '
            'the model executor, which can range from 0 to 1. '
            'If unspecified, will use the default value of 0.9.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--max-paddings',
                            type=int,
                            default=EngineArgs.max_paddings,
                            help='maximum number of paddings in a batch')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('max number of log probs to return when logprobs is '
                  'specified in SamplingParams'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        # `None` is listed among the choices so that the default (no explicit
        # method) is representable; only the strings can be passed on the CLI.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', 'gptq', 'squeezellm', None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        # NOTE(review): the help below says ">= max_num_seqs"; confirm which
        # limit LoRAConfig actually validates max_cpu_loras against.
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron"],
                            help='Device type for vLLM execution.')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Create an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` is looked up on ``args`` by name.

        Raises:
            AttributeError: if ``args`` lacks an attribute for one of the
                fields (e.g. the parser was not built with ``add_cli_args``).
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> Tuple['ModelConfig', 'CacheConfig', 'ParallelConfig',
               'SchedulerConfig', 'DeviceConfig', Optional['LoRAConfig']]:
        """Build the per-subsystem config objects from these arguments.

        Returns:
            ``(model_config, cache_config, parallel_config, scheduler_config,
            device_config, lora_config)``. ``lora_config`` is ``None`` unless
            ``enable_lora`` is set.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.download_dir, self.load_format,
            self.dtype, self.seed, self.revision, self.code_revision,
            self.tokenizer_revision, self.max_model_len, self.quantization,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_logprobs)
        # The cache needs the model's sliding window, so the model config is
        # built first.
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray,
                                         self.max_parallel_loading_workers,
                                         self.disable_custom_all_reduce,
                                         self.ray_workers_use_nsight)
        # Uses model_config.max_model_len rather than self.max_model_len,
        # presumably because ModelConfig fills it in from the model when the
        # user left it unset (per the --max-model-len help) — confirm.
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.max_model_len,
                                           self.max_paddings)
        # A missing or non-positive max_cpu_loras is passed on as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
        return (model_config, cache_config, parallel_config, scheduler_config,
                device_config, lora_config)
336
337


338
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Adds three async-only options on top of every ``EngineArgs`` field; each
    has a matching ``--kebab-case`` flag registered by ``add_cli_args``.
    """
    # Start the LLM engine via Ray in a process separate from the server.
    engine_use_ray: bool = False
    # Turn off logging of incoming requests.
    disable_log_requests: bool = False
    # Max prompt characters / prompt token IDs printed per logged request;
    # None means unlimited.
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the shared engine flags plus the async-only flags.

        Args:
            parser: Parser to extend (mutated in place).

        Returns:
            The same ``parser``, so calls can be chained.
        """
        parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='use Ray to start the LLM engine in a '
                            'separate process from the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='disable logging requests')
        # Read the default from the dataclass (like the EngineArgs flags) so
        # the CLI and programmatic defaults cannot drift apart.
        parser.add_argument('--max-log-len',
                            type=int,
                            default=AsyncEngineArgs.max_log_len,
                            help='max number of prompt characters or prompt '
                            'ID numbers being printed in log. '
                            'Default: unlimited.')
        return parser