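"""CacheFlow server entry point.

Connects to a Ray cluster, creates one controller per pipeline stage, and
drives the scheduler over a few hard-coded test prompts.
"""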
import argparse
import random
from typing import List, Tuple

import ray

from cacheflow.master.frontend import Frontend
from cacheflow.master.scheduler import Scheduler
from cacheflow.models import get_memory_analyzer
from cacheflow.worker.controller import Controller, DeviceID


def initialize_ray_cluster(
    address: str = 'auto',
    pipeline_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
    # Connect to a ray cluster.
    ray.init(address=address)

    # For now, assume a uniform cluster in which every node has the same
    # number of GPUs.
    valid_node_resources = []
    num_devices_per_node = None
    for node in ray.nodes():
        if not node['Alive'] or node['Resources'].get('GPU', 0) <= 0:
            continue
        if num_devices_per_node is None:
            num_devices_per_node = node['Resources']['GPU']
        else:
            assert num_devices_per_node == node['Resources']['GPU'], (
                "The number of GPUs per node is not uniform.")
        for key in node['Resources']:
            if key.startswith('node:'):
                valid_node_resources.append(key)

    num_nodes = len(valid_node_resources)

    assert (pipeline_parallel_size * tensor_parallel_size
            <= num_nodes * num_devices_per_node), (
                "The number of required GPUs exceeds the total number of "
                "available GPUs.")
    if tensor_parallel_size >= num_devices_per_node:
        assert tensor_parallel_size % num_devices_per_node == 0, (
            "The tensor parallel size is not divisible by the number of "
            "GPUs per node.")
    else:
        assert num_devices_per_node % tensor_parallel_size == 0, (
            "The number of GPUs per node is not divisible by the tensor "
            "parallel size.")

    # Assign GPUs to pipeline stages.
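    # Ranks are assigned in pipeline-stage-major order: each stage gets
    # `tensor_parallel_size` consecutive ranks, and devices are consumed node
    # by node. The node hosting the first device also provides the TCP
    # endpoint used as the distributed init method.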
    rank = 0
    current_node_id = 0
    current_device_id = 0
    distributed_init_method = None
    all_stage_devices = []

    for i in range(pipeline_parallel_size):
        stage_devices = []
        for j in range(tensor_parallel_size):
            node_resource = valid_node_resources[current_node_id]
            stage_devices.append((rank, node_resource, current_device_id))
            if distributed_init_method is None:
                ip = node_resource.split("node:")[-1]
                port = random.randint(10000, 20000)
                distributed_init_method = f"tcp://{ip}:{port}"
            rank += 1
            current_device_id += 1
            if current_device_id >= num_devices_per_node:
                current_node_id += 1
                current_device_id = 0
        all_stage_devices.append(stage_devices)

    return (num_nodes, num_devices_per_node, distributed_init_method,
            all_stage_devices)


def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    world_size = args.pipeline_parallel_size * args.tensor_parallel_size

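    # Compute how many KV cache blocks fit in GPU memory and in the CPU swap
    # space for this model and data type.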
    memory_analyzer = get_memory_analyzer(
        model_name=args.model,
        block_size=args.block_size,
        dtype=args.dtype,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
        max_num_batched_tokens=args.max_batch_size)
    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks(
        swap_space=args.swap_space)
    print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')

    # Create a controller for each pipeline stage.
    controllers: List[Controller] = []
    for i in range(args.pipeline_parallel_size):
        controller = Controller(
            stage_id=i,
            stage_devices=all_stage_devices[i],
            world_size=world_size,
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size,
            distributed_init_method=distributed_init_method,
            model_name=args.model,
            block_size=args.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=args.dtype,
            seed=args.seed,
            model_path=args.model_path,
        )
        controllers.append(controller)

    # Create a frontend.
    frontend = Frontend(
        model_name=args.model,
        block_size=args.block_size,
    )

    # Create a scheduler.
    scheduler = Scheduler(
        frontend=frontend,
        controllers=controllers,
        block_size=args.block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        max_num_batched_tokens=args.max_batch_size,
    )
    # Connect the controllers.
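    # Each stage forwards its outputs to the next stage, and the last stage
    # reports back to the scheduler.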
    for i in range(len(controllers) - 1):
        controllers[i].set_next(controllers[i + 1])
    controllers[-1].set_next(scheduler)

    # Test the following inputs.
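    # Each entry pairs a prompt with per-request sampling parameters; an
    # empty dict falls back to the frontend's defaults.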
    test_inputs = [
        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}),   # Use default parameters.
    ]
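    # Submit one test input per iteration and advance the scheduler by one
    # step, until every input has been submitted and no sequences remain
    # pending or running.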
    while True:
        if test_inputs:
            text, sampling_params = test_inputs.pop(0)
            frontend.query(text, **sampling_params)
        scheduler.step()
        if not (scheduler.pending or scheduler.running or test_inputs):
            break


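# Example invocation (assumes a running Ray cluster with at least one GPU):
#   python server.py --model facebook/opt-125m --tensor-parallel-size 1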
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow server')
    # Model arguments
    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
    parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                        help='path to download and load the model weights')
    # Parallel arguments
    parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
    # KV cache arguments
    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
    args = parser.parse_args()

    main(args)