server.py 10.9 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import argparse
2
from typing import List, Tuple, Optional
Zhuohan Li's avatar
Zhuohan Li committed
3
4
import random

5
6
7
8
9
import torch
try:
    import ray
except ImportError:
    ray = None
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11

from cacheflow.master.scheduler import Scheduler
12
from cacheflow.master.simple_frontend import SimpleFrontend
13
from cacheflow.models import get_memory_analyzer
Zhuohan Li's avatar
Zhuohan Li committed
14
from cacheflow.worker.controller import Controller, DeviceID
15
16
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams
17
from cacheflow.utils import get_gpu_memory, get_cpu_memory
18

19

20
21
22
23
class Server:
    """Owns the per-stage controllers and the scheduler that drives them.

    On construction it sizes the GPU and CPU block pools with the model's
    memory analyzer, creates one ``Controller`` per pipeline stage, creates a
    ``Scheduler`` over those controllers, and links each controller to the
    next via ``set_next`` (the last one is linked to the scheduler).
    """

    def __init__(
        self,
        model: str,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        gpu_memory: int,
        cpu_memory: int,
        use_ray: bool,
        collect_stats: bool = False,
        do_memory_analysis: bool = False,
    ):
        """Size the cache block pools and build the controller/scheduler chain.

        Args:
            model: Model name, passed to the memory analyzer and to every
                controller.
            cache_dir: Weight cache directory forwarded to the controllers.
            use_dummy_weights: Forwarded to the controllers.
            use_np_cache: Forwarded to the controllers.
            pipeline_parallel_size: Number of pipeline stages; one
                ``Controller`` is created per stage.
            tensor_parallel_size: Tensor-parallel degree, forwarded to the
                memory analyzer and the controllers.
            block_size: Token block size shared by the analyzer, the
                controllers and the scheduler.
            dtype: Data type name forwarded to the analyzer and controllers.
            seed: Random seed forwarded to the controllers.
            swap_space: Passed to ``get_max_num_cpu_blocks``; the CLI
                describes it as CPU swap space size (GiB) per GPU.
            max_num_batched_tokens: Used to size the GPU block pool and
                forwarded to the controllers and the scheduler.
            max_num_sequences: Forwarded to the scheduler.
            num_nodes: Stored on the instance; not otherwise used here.
            num_devices_per_node: Stored on the instance; not otherwise used
                here.
            distributed_init_method: Forwarded to every controller.
            all_stage_devices: Per-stage device lists; entry ``i`` is given
                to the controller of stage ``i``.
            gpu_memory: Passed to the memory analyzer.
            cpu_memory: Passed to the memory analyzer.
            use_ray: Forwarded to the controllers; without Ray only a world
                size of 1 is accepted.
            collect_stats: Forwarded to the scheduler.
            do_memory_analysis: Forwarded to the scheduler.

        Raises:
            AssertionError: If ``use_ray`` is false and
                ``pipeline_parallel_size * tensor_parallel_size != 1``.
        """
        self.num_nodes = num_nodes
        self.num_devices_per_node = num_devices_per_node
        self.world_size = pipeline_parallel_size * tensor_parallel_size

        if not use_ray:
            assert self.world_size == 1, (
                "Only support single GPU without Ray.")

        # Size the GPU and CPU block pools from the available memory.
        self.memory_analyzer = get_memory_analyzer(
            model_name=model,
            block_size=block_size,
            dtype=dtype,
            gpu_memory=gpu_memory,
            cpu_memory=cpu_memory,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
            max_num_batched_tokens=max_num_batched_tokens)
        self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
            swap_space=swap_space)
        print(f'# GPU blocks: {self.num_gpu_blocks}, '
              f'# CPU blocks: {self.num_cpu_blocks}')

        # Create a controller for each pipeline stage.
        self.controllers: List[Controller] = []
        for i in range(pipeline_parallel_size):
            controller = Controller(
                stage_id=i,
                stage_devices=all_stage_devices[i],
                world_size=self.world_size,
                pipeline_parallel_size=pipeline_parallel_size,
                tensor_parallel_size=tensor_parallel_size,
                distributed_init_method=distributed_init_method,
                model_name=model,
                block_size=block_size,
                num_gpu_blocks=self.num_gpu_blocks,
                num_cpu_blocks=self.num_cpu_blocks,
                dtype=dtype,
                seed=seed,
                cache_dir=cache_dir,
                use_dummy_weights=use_dummy_weights,
                use_np_cache=use_np_cache,
                max_num_batched_tokens=max_num_batched_tokens,
                use_ray=use_ray,
            )
            self.controllers.append(controller)

        # Create a scheduler.
        self.scheduler = Scheduler(
            controllers=self.controllers,
            block_size=block_size,
            num_gpu_blocks=self.num_gpu_blocks,
            num_cpu_blocks=self.num_cpu_blocks,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            collect_stats=collect_stats,
            do_memory_analysis=do_memory_analysis,
        )
        # Connect the controllers: stage i -> stage i + 1, last -> scheduler.
        for upstream, downstream in zip(self.controllers,
                                        self.controllers[1:]):
            upstream.set_next(downstream)
        self.controllers[-1].set_next(self.scheduler)

    def add_sequence_groups(
        self,
        sequence_groups: List[Tuple[SequenceGroup, SamplingParams]]
    ):
        """Hand new (sequence group, sampling params) pairs to the scheduler."""
        self.scheduler.add_sequence_groups(sequence_groups)

    def step(self):
        """Run one scheduler step and return whatever the scheduler returns."""
        return self.scheduler.step()

    def has_unfinished_requests(self) -> bool:
        """Return True if the scheduler still has waiting, running or swapped work."""
        # A bare ``or`` chain would evaluate to one of the scheduler's own
        # containers rather than a bool, leaking internal state to callers.
        return bool(self.scheduler.waiting or self.scheduler.running
                    or self.scheduler.swapped)
Zhuohan Li's avatar
Zhuohan Li committed
120
121


122
123
124
def initialize_cluster(
    use_ray: bool = False,
    address: Optional[str] = None,
    pipeline_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
    """Discover the available GPUs and assign one device to every worker rank.

    Without Ray only a single-GPU layout is supported. The result is one
    stage holding the single entry ``(0, None, 0)``, a ``tcp://localhost``
    init method on a random port, and ``torch.cuda.device_count()`` as the
    per-node device count.

    With Ray, this connects via ``ray.init(address=address)`` and collects
    every alive node that reports at least one GPU. It then fills ranks
    node by node: each rank takes the next free device on the current node
    and moves on to the next node once that node's GPUs are used up. The
    init method uses the address parsed from the first assigned node's
    ``node:<ip>`` resource key.

    Args:
        use_ray: Whether to lay out workers on a Ray cluster.
        address: Ray cluster address passed to ``ray.init``; unused without
            Ray.
        pipeline_parallel_size: Number of pipeline stages.
        tensor_parallel_size: Number of tensor-parallel workers per stage.

    Returns:
        ``(num_nodes, num_devices_per_node, distributed_init_method,
        all_stage_devices)``, where ``all_stage_devices[i]`` lists the
        ``(rank, node_resource, device_id)`` entries of pipeline stage ``i``.

    Raises:
        AssertionError: On an unsupported layout (multi-GPU without Ray),
            missing Ray, a non-uniform cluster, no GPU nodes, too few GPUs,
            or a tensor-parallel size incompatible with the GPUs per node.
    """
    # Initialize cluster locally.
    if not use_ray:
        assert pipeline_parallel_size * tensor_parallel_size == 1, (
            "Only support single GPU without Ray.")
        num_nodes = 1
        num_devices_per_node = torch.cuda.device_count()
        port = random.randint(10000, 20000)
        # We need to setup the distributed init method to make sure
        # the distributed megatron code (e.g., get world size) works correctly.
        distributed_init_method = f"tcp://localhost:{port}"
        all_stage_devices = [[(0, None, 0)]]
        return (num_nodes, num_devices_per_node, distributed_init_method,
                all_stage_devices)

    assert ray is not None, (
        "Ray is not installed. Please install Ray to use distributed "
        "serving.")

    # Connect to a ray cluster.
    ray.init(address=address)

    # Assume we have a uniform cluster that each node has the same number of
    # GPUs for now.
    valid_node_resources = []
    num_devices_per_node = None
    for node in ray.nodes():
        if not node['Alive']:
            continue
        # Ray leaves the 'GPU' key out of a node's resources when the node
        # has no GPUs (e.g. a CPU-only head node); indexing it directly
        # would raise KeyError instead of skipping the node.
        num_gpus = node['Resources'].get('GPU', 0)
        if num_gpus <= 0:
            continue
        if num_devices_per_node is None:
            num_devices_per_node = num_gpus
        else:
            assert num_devices_per_node == num_gpus, (
                "The number of GPUs per node is not uniform.")
        for key in node['Resources']:
            # Newer Ray releases also tag the head node with a special
            # 'node:__internal_head__' resource. It carries no IP address
            # and must not be counted as an additional node.
            if key.startswith('node:') and key != 'node:__internal_head__':
                valid_node_resources.append(key)

    # Without this check, an empty cluster fails below with an opaque
    # TypeError on ``None``.
    assert num_devices_per_node is not None, (
        "No alive node with GPUs was found in the Ray cluster.")
    # Ray may report resource quantities as floats (e.g. 8.0); device ids
    # and the declared return type are integers.
    num_devices_per_node = int(num_devices_per_node)

    num_nodes = len(valid_node_resources)

    assert (pipeline_parallel_size * tensor_parallel_size
            <= num_nodes * num_devices_per_node), (
                "The number of required GPUs exceeds the total number of "
                "available GPUs.")
    if tensor_parallel_size >= num_devices_per_node:
        assert tensor_parallel_size % num_devices_per_node == 0, (
            "The number of tensor parallelism is not divisible by the "
            "number of GPUs per node.")
    else:
        assert num_devices_per_node % tensor_parallel_size == 0, (
            "The number of GPUs per node is not divisible by the number "
            "of tensor parallelism.")

    # Assign GPUs to pipeline stages, filling each node before moving on.
    rank = 0
    current_node_id = 0
    current_device_id = 0
    distributed_init_method = None
    all_stage_devices = []

    for _ in range(pipeline_parallel_size):
        stage_devices = []
        for _ in range(tensor_parallel_size):
            node_resource = valid_node_resources[current_node_id]
            stage_devices.append((rank, node_resource, current_device_id))
            if distributed_init_method is None:
                # Rank 0's node hosts the rendezvous endpoint.
                ip = node_resource.split("node:")[-1]
                port = random.randint(10000, 20000)
                distributed_init_method = f"tcp://{ip}:{port}"
            rank += 1
            current_device_id += 1
            if current_device_id >= num_devices_per_node:
                current_node_id += 1
                current_device_id = 0
        all_stage_devices.append(stage_devices)

    return (num_nodes, num_devices_per_node, distributed_init_method,
            all_stage_devices)


207
def add_server_arguments(parser: argparse.ArgumentParser):
    """Register the server's command-line options on ``parser``.

    Covers model loading, parallelism, and cache/scheduling settings. The
    same ``parser`` object is returned so the call can be chained.
    """
    add = parser.add_argument

    # Model arguments
    add('--model', type=str, default='facebook/opt-125m',
        help='model name')
    add('--cache-dir', type=str, default=None,
        help=('cache dir to download and load the weights, '
              'default to the default cache dir of huggingface'))
    add('--use-np-cache', action='store_true',
        help='save a numpy copy of model weights for faster loading')
    add('--use-dummy-weights', action='store_true',
        help='use dummy values for model weights')
    # NOTE(woosuk): If FlashAttention is used, the float data type is not
    # supported.
    add('--dtype', type=str, default='half', choices=['half'],
        help='data type')

    # Parallel arguments
    add('--use-ray', action='store_true',
        help=('use Ray for distributed serving, will be automatically set '
              'when using more than 1 GPU'))
    add('--pipeline-parallel-size', '-pp', type=int, default=1,
        help='number of pipeline stages')
    add('--tensor-parallel-size', '-tp', type=int, default=1,
        help='number of tensor parallel replicas')

    # KV cache arguments
    add('--block-size', type=int, default=16,
        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
        help='token block size')
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    add('--seed', type=int, default=0,
        help='random seed')
    add('--swap-space', type=int, default=20,
        help='CPU swap space size (GiB) per GPU')
    add('--max-num-batched-tokens', type=int, default=2560,
        help='maximum number of batched tokens per iteration')
    add('--max-num-sequences', type=int, default=256,
        help='maximum number of sequences per iteration')
    return parser
230

231

232
233
234
235
def process_server_arguments(args: argparse.Namespace):
    """Adjust parsed server arguments in place and return the same namespace.

    ``use_ray`` is forced on whenever the requested layout (pipeline stages
    times tensor-parallel replicas) spans more than one GPU, because
    multi-GPU layouts are only supported through Ray.
    """
    requested_gpus = args.pipeline_parallel_size * args.tensor_parallel_size
    if requested_gpus <= 1:
        return args
    args.use_ray = True
    return args
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278


def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
    """Build a ``Server`` and a ``SimpleFrontend`` from parsed CLI arguments.

    ``args`` must carry the options registered by ``add_server_arguments``.
    ``initialize_cluster`` is called without an ``address``, so with Ray it
    runs ``ray.init(address=None)``. Only a single pipeline stage is
    accepted for now.

    Returns:
        A ``(server, frontend)`` tuple.
    """
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    cluster_layout = initialize_cluster(
        use_ray=args.use_ray,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    (num_nodes, num_devices_per_node,
     distributed_init_method, all_stage_devices) = cluster_layout

    # Query host memory only after the cluster has been set up.
    gpu_memory = get_gpu_memory()
    cpu_memory = get_cpu_memory()

    # Create a server.
    server = Server(
        model=args.model,
        cache_dir=args.cache_dir,
        use_dummy_weights=args.use_dummy_weights,
        use_np_cache=args.use_np_cache,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_sequences=args.max_num_sequences,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=gpu_memory,
        cpu_memory=cpu_memory,
        use_ray=args.use_ray,
    )

    # Create a frontend.
    frontend = SimpleFrontend(model_name=args.model,
                              block_size=args.block_size)
    return server, frontend