import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, LoRAConfig)

@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    The dataclass defaults double as the CLI defaults: ``add_cli_args``
    reads them via ``EngineArgs.<field>`` and ``from_cli_args`` copies the
    parsed values back onto the fields by name.
    """
    model: str
    tokenizer: Optional[str] = None  # falls back to `model` in __post_init__
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_paddings: int = 256
    disable_log_stats: bool = False
    revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    # BUGFIX: the annotation is required. Without it, this was a plain class
    # attribute rather than a dataclass field, so `from_cli_args` (which
    # iterates `dataclasses.fields`) silently dropped any user-supplied
    # --lora-dtype value.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None

    def __post_init__(self):
        # A missing tokenizer means "use the model's own tokenizer".
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine."""

        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=None,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=None,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # BUGFIX: added the missing space after "1." so the concatenated
            # help text no longer renders as "0 to 1.If unspecified".
            help='the fraction of GPU memory to be used for '
            'the model executor, which can range from 0 to 1. '
            'If unspecified, will use the default value of 0.9.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--max-paddings',
                            type=int,
                            default=EngineArgs.max_paddings,
                            help='maximum number of paddings in a batch')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', 'gptq', 'squeezellm', None],
                            default=None,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an EngineArgs from a parsed ``argparse.Namespace``."""
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
        # NOTE: the return annotation is a string (lazy forward reference) so
        # this module can be imported without evaluating the config classes.
    ) -> 'Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, Optional[LoRAConfig]]':
        """Materialize the vLLM config objects from these arguments.

        Returns (model, cache, parallel, scheduler, lora) configs; the
        LoRA config is None unless ``enable_lora`` is set.
        """
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.trust_remote_code,
                                   self.download_dir, self.load_format,
                                   self.dtype, self.seed, self.revision,
                                   self.tokenizer_revision, self.max_model_len,
                                   self.quantization, self.enforce_eager,
                                   self.max_context_len_to_capture)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space,
                                   model_config.get_sliding_window())
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray,
                                         self.max_parallel_loading_workers)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.max_model_len,
                                           self.max_paddings)
        # Non-positive max_cpu_loras is normalized to None ("use default").
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
        return model_config, cache_config, parallel_config, scheduler_config, lora_config
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends EngineArgs with options that only apply when the engine runs
    behind an asynchronous frontend (e.g. an API server process).
    """
    # Run the engine inside a Ray worker instead of the server process.
    engine_use_ray: bool = False
    # Suppress per-request logging.
    disable_log_requests: bool = False
    # Truncation limit for prompt text / token ids in logs (None = unlimited).
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register async-engine flags on top of the shared engine flags."""
        # Start from the common engine arguments, then append async-only ones.
        parser = EngineArgs.add_cli_args(parser)
        parser.add_argument(
            '--engine-use-ray',
            action='store_true',
            help='use Ray to start the LLM engine in a '
            'separate process as the server process.')
        parser.add_argument(
            '--disable-log-requests',
            action='store_true',
            help='disable logging requests')
        parser.add_argument(
            '--max-log-len',
            type=int,
            default=None,
            help='max number of prompt characters or prompt '
            'ID numbers being printed in log. '
            'Default: unlimited.')
        return parser