"vscode:/vscode.git/clone" did not exist on "2a75c6bc91fc39a0b35b6ead9bee8f05ee3e420a"
arg_utils.py 17.8 KB
Newer Older
1
import argparse
2
3
4
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
5

6
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
7
8
                         ParallelConfig, SchedulerConfig, LoRAConfig,
                         TokenizerPoolConfig)
9
10


11
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every field mirrors a CLI option registered in ``add_cli_args``
    (``--foo-bar`` is stored by argparse as ``foo_bar``). ``from_cli_args``
    reads each dataclass field back from the parsed namespace by name, so a
    field and its option must stay in sync.
    """
    model: str
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    # From the CLI this arrives as a JSON string (see the help text of
    # --tokenizer-pool-extra-config); it is handed to
    # TokenizerPoolConfig.create_config unchanged.
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    # The annotation is required: dataclasses only treat *annotated* class
    # attributes as fields. Without it, lora_dtype was not an __init__
    # parameter and from_cli_args silently ignored --lora-dtype.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    scheduler_delay_factor: float = 0.0

    def __post_init__(self):
        # Default the tokenizer to the model path/name when not given.
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one option per ``EngineArgs`` field on ``parser`` and
        returns the same parser object.
        """

        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=EngineArgs.revision,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=str,
            default=EngineArgs.code_revision,
            help='the specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=EngineArgs.tokenizer_revision,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8_e5m2'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. Note FP8 is not supported when cuda version is '
            'lower than 11.8.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile ray workers')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32, 128],
                            help='token block size')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='the fraction of GPU memory to be used for '
            'the model executor, which can range from 0 to 1. '
            'If unspecified, will use the default value of 0.9.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('max number of log probs to return when logprobs is '
                  'specified in SamplingParams'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', 'gptq', 'squeezellm', None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron"],
                            help='Device type for vLLM execution.')
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a namespace produced by ``add_cli_args``.

        Raises:
            AttributeError: if ``args`` lacks an attribute for some
                dataclass field (i.e. the option was never registered).
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
               DeviceConfig, Optional[LoRAConfig]]:
        """Translate these arguments into the engine's config objects.

        Returns:
            ``(model_config, cache_config, parallel_config,
            scheduler_config, device_config, lora_config)``; ``lora_config``
            is ``None`` unless ``enable_lora`` is set.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.download_dir, self.load_format,
            self.dtype, self.seed, self.revision, self.code_revision,
            self.tokenizer_revision, self.max_model_len, self.quantization,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_logprobs)
        # NOTE(review): enable_prefix_caching is parsed but not forwarded to
        # any config here; confirm whether CacheConfig accepts it at this
        # version and pass it through if so.
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   model_config.get_sliding_window())
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)
        # The scheduler uses the model's resolved context length, which may
        # have been derived from the model when max_model_len is None.
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.max_model_len,
                                           self.scheduler_delay_factor)
        # A non-positive or missing max_cpu_loras is normalised to None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
        return (model_config, cache_config, parallel_config, scheduler_config,
                device_config, lora_config)
363
364


365
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with options specific to the async engine; each
    extra field has a matching CLI option registered in ``add_cli_args``.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    # None means prompt logging is not truncated ("Default: unlimited.").
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the shared engine options plus async-only options.

        Returns the same ``parser`` object that was passed in.
        """
        parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='disable logging requests')
        # Take the default from the dataclass, like the EngineArgs options.
        parser.add_argument('--max-log-len',
                            type=int,
                            default=AsyncEngineArgs.max_log_len,
                            help='max number of prompt characters or prompt '
                            'ID numbers being printed in log. '
                            'Default: unlimited.')
        return parser