"docs/source/en/model_doc/clip.md" did not exist on "9643ecf8ca15b088784512e3329b9a7d6a02931d"
arg_utils.py 9.67 KB
Newer Older
1
import argparse
2
3
4
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
8
9


10
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every field mirrors a command-line flag registered by ``add_cli_args``.
    ``from_cli_args`` builds an instance from a parsed namespace, and
    ``create_engine_configs`` converts the instance into the model, cache,
    parallel and scheduler config objects.
    """
    # --- Model arguments ---
    # Name or path of the HuggingFace model to use.
    model: str
    # Name or path of the HuggingFace tokenizer; when left as None,
    # ``__post_init__`` falls back to ``model``.
    tokenizer: Optional[str] = None
    # 'auto' uses the fast tokenizer if available, 'slow' always uses the
    # slow tokenizer.
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    # Weight download/load directory; None means the HuggingFace default
    # cache dir.
    download_dir: Optional[str] = None
    # One of 'auto', 'pt', 'safetensors', 'npcache', 'dummy' on the CLI.
    load_format: str = 'auto'
    # One of 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
    # on the CLI.
    dtype: str = 'auto'
    seed: int = 0
    # Model context length; None means derive it from the model.
    max_model_len: Optional[int] = None
    # --- Parallel arguments ---
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    # --- KV cache arguments ---
    block_size: int = 16
    swap_space: int = 4  # GiB
    # Fraction of GPU memory used by the model executor (0.90 == 90%).
    gpu_memory_utilization: float = 0.90
    # --- Scheduler arguments ---
    max_num_batched_tokens: int = 2560
    max_num_seqs: int = 256
    # --- Misc ---
    disable_log_stats: bool = False
    # Specific model version: a branch name, tag name, or commit id.
    revision: Optional[str] = None
    # Weight quantization method; the CLI only accepts 'awq'.
    quantization: Optional[str] = None

    def __post_init__(self):
        """Fill in values that depend on other fields."""
        if self.tokenizer is None:
            # No explicit tokenizer: load it from the model name/path.
            self.tokenizer = self.model
        # Clamp the per-iteration sequence count to the per-iteration token
        # budget (presumably because each scheduled sequence occupies at
        # least one batched token).
        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Args:
            parser: parser to register the engine flags on (mutated).

        Returns:
            The same ``parser``, for chaining.
        """
        # Model arguments
        # Unlike the dataclass field (which is required), the CLI supplies a
        # default model.
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=EngineArgs.revision,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        # The default (0.90) is a fraction, and the two literal pieces must
        # be separated by a space ("for the", not "forthe").
        parser.add_argument('--gpu-memory-utilization',
                            type=float,
                            default=EngineArgs.gpu_memory_utilization,
                            help='the fraction of GPU memory to be used for '
                            'the model executor')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a namespace produced by ``add_cli_args``.

        Every dataclass field of ``cls`` is read from ``args`` by name, so
        ``args`` must carry an attribute for each field (an AttributeError
        is raised otherwise).
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
        """Convert these arguments into the engine's config objects.

        Returns:
            ``(model_config, cache_config, parallel_config,
            scheduler_config)``. The scheduler config receives the context
            length resolved by ``model_config.get_max_model_len()``.
        """
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.trust_remote_code,
                                   self.download_dir, self.load_format,
                                   self.dtype, self.seed, self.revision,
                                   self.max_model_len, self.quantization)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.get_max_model_len())
        return model_config, cache_config, parallel_config, scheduler_config


192
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends the base engine arguments with three async-specific options
    and the matching command-line flags.
    """
    # Start the LLM engine with Ray in a process separate from the server.
    engine_use_ray: bool = False
    # Turn off per-request logging.
    disable_log_requests: bool = False
    # Cap on prompt characters / prompt ID numbers printed in logs;
    # None means unlimited.
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base engine flags followed by the async-only flags.

        Args:
            parser: parser handed to ``EngineArgs.add_cli_args`` first.

        Returns:
            The parser returned by ``EngineArgs.add_cli_args``, extended
            with ``--engine-use-ray``, ``--disable-log-requests`` and
            ``--max-log-len``.
        """
        extended = EngineArgs.add_cli_args(parser)
        # On/off switches, registered in this exact order.
        switches = (
            ('--engine-use-ray',
             'use Ray to start the LLM engine in a '
             'separate process as the server process.'),
            ('--disable-log-requests', 'disable logging requests'),
        )
        for flag, description in switches:
            extended.add_argument(flag, action='store_true', help=description)
        extended.add_argument('--max-log-len',
                              type=int,
                              default=AsyncEngineArgs.max_log_len,
                              help='max number of prompt characters or prompt '
                              'ID numbers being printed in log. '
                              'Default: unlimited.')
        return extended