import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)


@dataclass
class ServerArgs:
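    """Arguments for CacheFlow servers."""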
    model: str
    download_dir: Optional[str] = None
    use_np_weights: bool = False
    use_dummy_weights: bool = False
    dtype: str = "auto"
    seed: int = 0
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    block_size: int = 16
    swap_space: int = 4  # CPU swap space per GPU, in GiB
    gpu_memory_utilization: float = 0.95
    max_num_batched_tokens: int = 2560
    max_num_seqs: int = 256
    disable_log_stats: bool = False

    def __post_init__(self):
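        # Each scheduled sequence contributes at least one token per
        # iteration, so a batch can never contain more sequences than
        # its token budget; cap max_num_seqs accordingly.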
        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Shared CLI arguments for CacheFlow servers."""
        # Model arguments
        parser.add_argument('--model', type=str, default='facebook/opt-125m',
                            help='name or path of the huggingface model to use')
        parser.add_argument('--download-dir', type=str,
                            default=ServerArgs.download_dir,
                            help='directory to download and load the weights; '
                                 'defaults to the default huggingface cache '
                                 'directory')
        parser.add_argument('--use-np-weights', action='store_true',
                            help='save a numpy copy of model weights for '
                                 'faster loading. This can increase the disk '
                                 'usage by up to 2x.')
        parser.add_argument('--use-dummy-weights', action='store_true',
                            help='use dummy values for model weights')
        # TODO(woosuk): Support FP32.
        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
                            choices=['auto', 'half', 'bfloat16', 'float'],
                            help='data type for model weights and activations. '
                                 'The "auto" option will use FP16 precision '
                                 'for FP32 and FP16 models, and BF16 precision '
                                 'for BF16 models.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray', action='store_true',
                            help='use Ray for distributed serving; set '
                                 'automatically when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
                            default=ServerArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
                            default=ServerArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size', type=int,
                            default=ServerArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space', type=int,
                            default=ServerArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument('--gpu-memory-utilization', type=float,
                            default=ServerArgs.gpu_memory_utilization,
                            help='the fraction of GPU memory to be used for '
                                 'the model executor')
        parser.add_argument('--max-num-batched-tokens', type=int,
                            default=ServerArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                                 'iteration')
        parser.add_argument('--max-num-seqs', type=int,
                            default=ServerArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats', action='store_true',
                            help='disable logging statistics')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return server_args

    def create_server_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
        # Initialize the configs.
        model_config = ModelConfig(
            self.model, self.download_dir, self.use_np_weights,
            self.use_dummy_weights, self.dtype, self.seed)
        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                   self.swap_space)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs)
        return model_config, cache_config, parallel_config, scheduler_config


@dataclass
class AsyncServerArgs(ServerArgs):
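    """Arguments for asynchronous CacheFlow servers."""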
    server_use_ray: bool = False

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        parser = ServerArgs.add_cli_args(parser)
        parser.add_argument('--server-use-ray', action='store_true',
                            help='use Ray to start the LLM server in a '
                                 'process separate from the web server '
                                 'process')
        return parser
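

# Illustrative usage sketch (an assumption, not part of the original module):
# wire ServerArgs into an argparse entry point and build the four configs.
# The component that would consume these configs (e.g. a server or engine
# constructor) is assumed and not shown here.
if __name__ == '__main__':
    example_parser = argparse.ArgumentParser(
        description='CacheFlow server arguments demo')
    example_parser = ServerArgs.add_cli_args(example_parser)
    example_args = example_parser.parse_args()
    server_args = ServerArgs.from_cli_args(example_args)
    (model_config, cache_config,
     parallel_config, scheduler_config) = server_args.create_server_configs()
    print(model_config, cache_config, parallel_config, scheduler_config)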