arg_utils.py 10 KB
Newer Older
1
import argparse
2
3
4
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
8
9


10
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every field has a matching command-line flag registered by
    ``add_cli_args`` (the flag name is the field name with underscores
    replaced by dashes), and ``from_cli_args`` reads the parsed values back
    by field name. ``create_engine_configs`` turns the collected values into
    the model, cache, parallel and scheduler config objects.
    """
    model: str
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    block_size: int = 16
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    disable_log_stats: bool = False
    revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None

    def __post_init__(self):
        # Without an explicit tokenizer, use the one shipped with the model.
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per ``EngineArgs`` field on ``parser`` and returns
        the same parser so calls can be chained. Defaults are taken from the
        dataclass fields, except ``--model``, whose field has no default.
        """
        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=EngineArgs.revision,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=EngineArgs.tokenizer_revision,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        # The value is a fraction (default 0.90), not a percentage; the
        # trailing space keeps the concatenated help text readable.
        parser.add_argument('--gpu-memory-utilization',
                            type=float,
                            default=EngineArgs.gpu_memory_utilization,
                            help='the fraction of GPU memory to be used for '
                            'the model executor, which can range from 0 to 1')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` (including fields added by
        subclasses) is read from ``args`` by name, so ``args`` must come from
        a parser configured with the matching ``add_cli_args``.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> 'Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]':
        """Create the engine config objects from these arguments.

        Returns:
            A ``(model_config, cache_config, parallel_config,
            scheduler_config)`` tuple. The cache config receives the HF
            config's ``sliding_window`` attribute when it exists (``None``
            otherwise). The scheduler config receives ``max_model_len`` as
            read back from the constructed ``ModelConfig`` rather than
            ``self.max_model_len``.
        """
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.trust_remote_code,
                                   self.download_dir, self.load_format,
                                   self.dtype, self.seed, self.revision,
                                   self.tokenizer_revision, self.max_model_len,
                                   self.quantization)
        cache_config = CacheConfig(
            self.block_size, self.gpu_memory_utilization, self.swap_space,
            getattr(model_config.hf_config, 'sliding_window', None))
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.max_model_len)
        return model_config, cache_config, parallel_config, scheduler_config


200
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with three extra options for the asynchronous
    engine. Each extra field has a matching flag registered by
    ``add_cli_args``, so the inherited ``from_cli_args`` can populate it.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for the asynchronous vLLM engine.

        Registers all ``EngineArgs`` flags, then the flags for the fields
        declared on this class, and returns the parser.
        """
        parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='disable logging requests')
        # Default comes from the dataclass field, like the base-class flags.
        parser.add_argument('--max-log-len',
                            type=int,
                            default=AsyncEngineArgs.max_log_len,
                            help='max number of prompt characters or prompt '
                            'ID numbers being printed in log. '
                            'Default: unlimited.')
        return parser