arg_utils.py 7.12 KB
Newer Older
1
import argparse
2
3
4
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
8
9


10
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Each field has a matching command-line flag registered by
    :meth:`add_cli_args`. :meth:`from_cli_args` builds an instance from a
    parsed ``argparse.Namespace``, and :meth:`create_engine_configs` groups
    the fields into ``ModelConfig``, ``CacheConfig``, ``ParallelConfig`` and
    ``SchedulerConfig`` objects.
    """
    # Model arguments.
    model: str  # name or path of the huggingface model to use
    tokenizer: Optional[str] = None  # None -> falls back to `model`
    tokenizer_mode: str = "auto"  # "auto" (fast if available) or "slow"
    download_dir: Optional[str] = None  # None -> huggingface default cache dir
    use_np_weights: bool = False  # keep a numpy copy of weights (see CLI help)
    use_dummy_weights: bool = False  # use dummy values for model weights
    dtype: str = "auto"
    seed: int = 0
    # Parallel arguments.
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    # KV cache arguments.
    block_size: int = 16
    swap_space: int = 4  # GiB of CPU swap space per GPU
    gpu_memory_utilization: float = 0.90  # fraction of GPU memory to use
    # Scheduler arguments.
    max_num_batched_tokens: int = 2560
    max_num_seqs: int = 256
    disable_log_stats: bool = False

    def __post_init__(self):
        """Fill in derived defaults after dataclass initialization."""
        # Without an explicit tokenizer, use the one shipped with the model.
        if self.tokenizer is None:
            self.tokenizer = self.model
        # Never allow more sequences per iteration than batched tokens.
        # NOTE(review): presumably because every scheduled sequence contributes
        # at least one token to the batch — confirm against the scheduler.
        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Args:
            parser: Parser to register the engine flags on.

        Returns:
            The same parser, to allow chaining.
        """
        # Model arguments
        parser.add_argument('--model', type=str, default='facebook/opt-125m',
                            help='name or path of the huggingface model to use')
        parser.add_argument('--tokenizer', type=str, default=EngineArgs.tokenizer,
                            help='name or path of the huggingface tokenizer to use')
        parser.add_argument('--tokenizer-mode', type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                                 'tokenizer if available, and "slow" will '
                                 'always use the slow tokenizer.')
        parser.add_argument('--download-dir', type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                                 'default to the default cache dir of '
                                 'huggingface')
        parser.add_argument('--use-np-weights', action='store_true',
                            help='save a numpy copy of model weights for '
                                 'faster loading. This can increase the disk '
                                 'usage by up to 2x.')
        parser.add_argument('--use-dummy-weights', action='store_true',
                            help='use dummy values for model weights')
        # TODO(woosuk): Support FP32.
        parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
                            choices=['auto', 'half', 'bfloat16', 'float'],
                            help='data type for model weights and activations. '
                                 'The "auto" option will use FP16 precision '
                                 'for FP32 and FP16 models, and BF16 precision '
                                 'for BF16 models.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray', action='store_true',
                            help='use Ray for distributed serving, will be '
                                 'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size', type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed', type=int, default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space', type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        # The value is a fraction (default 0.90), not a percentage; the two
        # literals below are concatenated, so keep the separating space.
        parser.add_argument('--gpu-memory-utilization', type=float,
                            default=EngineArgs.gpu_memory_utilization,
                            help='the fraction of GPU memory to be used for '
                                 'the model executor')
        parser.add_argument('--max-num-batched-tokens', type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                                 'iteration')
        parser.add_argument('--max-num-seqs', type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats', action='store_true',
                            help='disable logging statistics')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
        """Build an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` must exist as an attribute of
        ``args``; ``getattr`` raises ``AttributeError`` otherwise.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
        """Group these arguments into the four engine config objects.

        Returns:
            ``(model_config, cache_config, parallel_config, scheduler_config)``.
        """
        # Initialize the configs. Arguments are passed positionally, so their
        # order must match each config constructor's signature.
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode, self.download_dir,
            self.use_np_weights, self.use_dummy_weights, self.dtype, self.seed)
        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                   self.swap_space)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs)
        return model_config, cache_config, parallel_config, scheduler_config


131
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Adds two async-engine switches on top of every field inherited from
    ``EngineArgs``; both are plain on/off CLI flags.
    """
    # Per its CLI help: start the LLM engine through Ray in a separate process.
    engine_use_ray: bool = False
    # Per its CLI help: turn off logging of incoming requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Register the shared engine flags followed by the async-only flags.

        Args:
            parser: Parser to extend.

        Returns:
            The parser returned by ``EngineArgs.add_cli_args`` with the extra
            flags added.
        """
        extended = EngineArgs.add_cli_args(parser)
        # (flag, help) pairs; each is registered as a store_true switch, in
        # this order.
        async_switches = (
            ('--engine-use-ray',
             'use Ray to start the LLM engine in a '
             'separate process as the server process.'),
            ('--disable-log-requests',
             'disable logging requests'),
        )
        for flag_name, flag_help in async_switches:
            extended.add_argument(flag_name, action='store_true',
                                  help=flag_help)
        return extended