arg_utils.py 8.02 KB
Newer Older
1
import argparse
2
3
4
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
8
9


10
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every field has a matching command-line flag registered by
    ``add_cli_args``. The flag is the field name with ``_`` spelled as
    ``-``, so argparse stores it under the field's own name. That is what
    lets ``from_cli_args`` copy a parsed namespace field-by-field.
    """
    # Name or path of the Hugging Face model to use.
    model: str
    # Name or path of the tokenizer; ``None`` means "same as ``model``"
    # (resolved in ``__post_init__``).
    tokenizer: Optional[str] = None
    # 'auto' uses the fast tokenizer if available, 'slow' always uses the
    # slow tokenizer.
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    # Directory to download/load weights from; ``None`` means the Hugging
    # Face default cache dir.
    download_dir: Optional[str] = None
    use_np_weights: bool = False
    use_dummy_weights: bool = False
    # One of 'auto', 'half', 'bfloat16', 'float' (the ``--dtype`` choices).
    dtype: str = 'auto'
    seed: int = 0
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    # Tokens per KV-cache block; the CLI restricts this to 8, 16 or 32.
    block_size: int = 16
    swap_space: int = 4  # GiB of CPU swap space per GPU
    # Fraction (not percent) of GPU memory the model executor may use.
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: int = 2560
    # Clamped to ``max_num_batched_tokens`` in ``__post_init__``.
    max_num_seqs: int = 256
    disable_log_stats: bool = False

    def __post_init__(self):
        # Fall back to the model's own tokenizer when none is given.
        if self.tokenizer is None:
            self.tokenizer = self.model
        # Presumably every scheduled sequence needs at least one batched
        # token per iteration, so more sequences than tokens cannot fit.
        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per ``EngineArgs`` field on ``parser`` and
        returns that same parser so calls can be chained.
        """
        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument('--use-np-weights',
                            action='store_true',
                            help='save a numpy copy of model weights for '
                            'faster loading. This can increase the disk '
                            'usage by up to 2x.')
        parser.add_argument('--use-dummy-weights',
                            action='store_true',
                            help='use dummy values for model weights')
        # TODO(woosuk): Support FP32.
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=['auto', 'half', 'bfloat16', 'float'],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        # The value is a fraction (default 0.90), and the two literal
        # pieces need a separating space when concatenated.
        parser.add_argument('--gpu-memory-utilization',
                            type=float,
                            default=EngineArgs.gpu_memory_utilization,
                            help='the fraction of GPU memory to be used for '
                            'the model executor')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a parsed argparse namespace.

        Every dataclass field of ``cls`` is read from the attribute of the
        same name on ``args``. A namespace missing any field therefore
        raises ``AttributeError``.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> 'Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]':
        """Build the model, cache, parallel and scheduler configs.

        The scheduler config is given the maximum model length reported by
        the ``ModelConfig`` built here.
        """
        # Initialize the configs.
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.trust_remote_code,
                                   self.download_dir, self.use_np_weights,
                                   self.use_dummy_weights, self.dtype,
                                   self.seed)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.get_max_model_len())
        return model_config, cache_config, parallel_config, scheduler_config


164
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with two boolean switches, both ``False`` by
    default, each exposed as a ``store_true`` command-line flag.
    """
    # Start the LLM engine with Ray in a process separate from the server.
    engine_use_ray: bool = False
    # Turn off logging of incoming requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base engine flags, then the async-only flags.

        Returns the parser handed back by ``EngineArgs.add_cli_args``.
        """
        extended = EngineArgs.add_cli_args(parser)
        # (flag, help text) pairs, registered in this order as on/off switches.
        async_flags = (
            ('--engine-use-ray',
             'use Ray to start the LLM engine in a '
             'separate process as the server process.'),
            ('--disable-log-requests', 'disable logging requests'),
        )
        for flag, description in async_flags:
            extended.add_argument(flag, action='store_true', help=description)
        return extended