arg_utils.py
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)


@dataclass
class ServerArgs:
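    """Shared command-line arguments for configuring a CacheFlow server."""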
    model: str
    download_dir: Optional[str] = None
    use_np_weights: bool = False
    use_dummy_weights: bool = False
    dtype: str = "default"
    seed: int = 0
    use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    block_size: int = 16
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.95  # fraction of total GPU memory
    max_num_batched_tokens: int = 2560
    max_num_seqs: int = 256
    disable_log_stats: bool = False

    def __post_init__(self):
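        # A batch can never contain more sequences than tokens, since each
        # scheduled sequence contributes at least one token per iteration.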
        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
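        """Add the shared CacheFlow server CLI arguments to the parser."""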
        return _add_server_arguments(parser)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return server_args

    def create_server_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
        # Initialize the configs.
        model_config = ModelConfig(
            self.model, self.download_dir, self.use_np_weights,
            self.use_dummy_weights, self.dtype, self.seed)
        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                   self.swap_space)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.use_ray)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs)
        return model_config, cache_config, parallel_config, scheduler_config


def _add_server_arguments(
    parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    """Shared CLI arguments for CacheFlow servers."""
    # Model arguments
    parser.add_argument('--model', type=str, default='facebook/opt-125m',
                        help='name or path of the huggingface model to use')
    parser.add_argument('--download-dir', type=str,
                        default=ServerArgs.download_dir,
                        help='directory to download and load the weights; '
                             'defaults to the HuggingFace cache directory')
    parser.add_argument('--use-np-weights', action='store_true',
                        help='save a numpy copy of model weights for faster '
                             'loading. This can increase the disk usage by up '
                             'to 2x.')
    parser.add_argument('--use-dummy-weights', action='store_true',
                        help='use dummy values for model weights')
    # TODO(woosuk): Support FP32.
    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
                        choices=['default', 'half', 'bfloat16'],
                        help=('data type for model weights and activations. '
                              'The "default" option will use FP16 precision '
                              'for FP32 and FP16 models, and BF16 precision '
                              'for BF16 models.'))
    # Parallel arguments
    parser.add_argument('--use-ray', action='store_true',
                        help='use Ray for distributed serving; this flag is '
                             'set automatically when using more than one GPU')
    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
                        default=ServerArgs.pipeline_parallel_size,
                        help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
                        default=ServerArgs.tensor_parallel_size,
                        help='number of tensor parallel replicas')
    # KV cache arguments
    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                        help='token block size')
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
                        help='random seed')
    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
                        help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--gpu-memory-utilization', type=float,
                        default=ServerArgs.gpu_memory_utilization,
                        help='the fraction of GPU memory to be used for the '
                             'model executor')
    parser.add_argument('--max-num-batched-tokens', type=int,
                        default=ServerArgs.max_num_batched_tokens,
                        help='maximum number of batched tokens per iteration')
    parser.add_argument('--max-num-seqs', type=int,
                        default=ServerArgs.max_num_seqs,
                        help='maximum number of sequences per iteration')
    parser.add_argument('--disable-log-stats', action='store_true',
                        help='disable logging statistics')
    return parser
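

# Minimal usage sketch (illustrative only, not one of the server entry
# points): build a parser with the shared arguments, parse the command line,
# and materialize the model/cache/parallel/scheduler config objects.
if __name__ == "__main__":
    parser = ServerArgs.add_cli_args(argparse.ArgumentParser())
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    model_config, cache_config, parallel_config, scheduler_config = (
        server_args.create_server_configs())
    print(model_config, cache_config, parallel_config, scheduler_config)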