"docs/vscode:/vscode.git/clone" did not exist on "4634b00c2a6ec85978b0653781ef6b7397d80509"
server.py 7.7 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import argparse
2
from typing import List, Tuple
Zhuohan Li's avatar
Zhuohan Li committed
3
4
5
import random

import ray
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7

from cacheflow.master.scheduler import Scheduler
8
from cacheflow.models import get_memory_analyzer
Zhuohan Li's avatar
Zhuohan Li committed
9
from cacheflow.worker.controller import Controller, DeviceID
10
11
12
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams

13

14
15
16
17
18
class Server:
    """Owns the scheduler and one controller per pipeline stage.

    Construction sizes the GPU/CPU block pools through the memory analyzer,
    creates a ``Controller`` for every pipeline stage, creates the
    ``Scheduler``, and links them with ``set_next``: each controller points
    at the controller of the following stage, and the last controller points
    at the scheduler.
    """

    def __init__(
        self,
        model: str,
        model_path: str,
        use_dummy_weights: bool,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        gpu_memory: int,
        cpu_memory: int,
        collect_stats: bool = False,
        do_memory_analysis: bool = False,
    ):
        """Build the controllers and the scheduler and chain them together.

        Args:
            model: Model name; forwarded as ``model_name`` to the memory
                analyzer and to every controller.
            model_path: Forwarded to every controller.
            use_dummy_weights: Forwarded to every controller.
            pipeline_parallel_size: Number of pipeline stages; one controller
                is created per stage.
            tensor_parallel_size: Forwarded to the memory analyzer and to
                every controller.
            block_size: Forwarded to the memory analyzer, the controllers and
                the scheduler.
            dtype: Forwarded to the memory analyzer and the controllers.
            seed: Forwarded to every controller.
            swap_space: Passed to ``get_max_num_cpu_blocks`` to size the CPU
                block pool.
            max_num_batched_tokens: Passed to ``get_max_num_gpu_blocks`` and
                forwarded to the controllers and the scheduler.
            max_num_sequences: Forwarded to the scheduler.
            num_nodes: Stored on the instance.
            num_devices_per_node: Stored on the instance.
            distributed_init_method: Forwarded to every controller.
            all_stage_devices: Per-stage device lists, indexed by stage id;
                needs at least ``pipeline_parallel_size`` entries.
            gpu_memory: Forwarded to the memory analyzer.
            cpu_memory: Forwarded to the memory analyzer.
            collect_stats: Forwarded to the scheduler.
            do_memory_analysis: Forwarded to the scheduler.
        """
        self.num_nodes = num_nodes
        self.num_devices_per_node = num_devices_per_node
        self.world_size = pipeline_parallel_size * tensor_parallel_size

        self.memory_analyzer = get_memory_analyzer(
            model_name=model,
            block_size=block_size,
            dtype=dtype,
            gpu_memory=gpu_memory,
            cpu_memory=cpu_memory,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
            max_num_batched_tokens=max_num_batched_tokens)
        self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
            swap_space=swap_space)
        print(f'# GPU blocks: {self.num_gpu_blocks}, '
              f'# CPU blocks: {self.num_cpu_blocks}')

        # Create a controller for each pipeline stage.
        self.controllers: List[Controller] = []
        for i in range(pipeline_parallel_size):
            controller = Controller(
                stage_id=i,
                stage_devices=all_stage_devices[i],
                world_size=self.world_size,
                pipeline_parallel_size=pipeline_parallel_size,
                tensor_parallel_size=tensor_parallel_size,
                distributed_init_method=distributed_init_method,
                model_name=model,
                block_size=block_size,
                num_gpu_blocks=self.num_gpu_blocks,
                num_cpu_blocks=self.num_cpu_blocks,
                dtype=dtype,
                seed=seed,
                model_path=model_path,
                use_dummy_weights=use_dummy_weights,
                max_num_batched_tokens=max_num_batched_tokens,
            )
            self.controllers.append(controller)

        # Create a scheduler.
        self.scheduler = Scheduler(
            controllers=self.controllers,
            block_size=block_size,
            num_gpu_blocks=self.num_gpu_blocks,
            num_cpu_blocks=self.num_cpu_blocks,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            collect_stats=collect_stats,
            do_memory_analysis=do_memory_analysis,
        )
        # Connect the controllers: stage i hands over to stage i + 1, and the
        # final stage hands over to the scheduler.
        for i in range(len(self.controllers) - 1):
            self.controllers[i].set_next(self.controllers[i + 1])
        self.controllers[-1].set_next(self.scheduler)

    def add_sequence_groups(
        self,
        sequence_groups: List[Tuple[SequenceGroup, SamplingParams]]
    ):
        """Hand the (sequence group, sampling params) pairs to the scheduler."""
        self.scheduler.add_sequence_groups(sequence_groups)

    def step(self):
        """Run one scheduler step and return whatever the scheduler returns."""
        return self.scheduler.step()

    def has_unfinished_requests(self) -> bool:
        """Return True while any scheduler queue is non-empty.

        The waiting, running and swapped queues of the scheduler are checked.
        The result is coerced to ``bool`` so callers receive a plain flag
        rather than a reference to one of the scheduler's internal queues.
        """
        scheduler = self.scheduler
        return bool(scheduler.waiting or scheduler.running or
                    scheduler.swapped)


def initialize_ray_cluster(
    address: str = 'auto',
    pipeline_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
    """Connect to a Ray cluster and assign its GPUs to pipeline stages.

    Every alive node that reports a positive GPU count contributes its
    ``node:...`` resource keys. Devices are then handed out in rank order:
    a node is filled up before moving on to the next one, and each pipeline
    stage receives ``tensor_parallel_size`` consecutive ranks.

    Args:
        address: Passed to ``ray.init``.
        pipeline_parallel_size: Number of pipeline stages.
        tensor_parallel_size: Number of devices assigned to each stage.

    Returns:
        A tuple ``(num_nodes, num_devices_per_node, distributed_init_method,
        all_stage_devices)``. ``all_stage_devices`` holds one list per stage
        of ``(rank, node_resource, device_id)`` tuples.
        ``distributed_init_method`` is ``tcp://<ip>:<port>``, where the ip is
        taken from the resource key of the node hosting rank 0 and the port
        is drawn at random from 10000-20000.

    Raises:
        AssertionError: If no alive GPU node is found, the GPU count differs
            between nodes, more GPUs are requested than available, or the
            tensor parallel size and the per-node GPU count do not divide
            one another.
    """
    # Connect to a ray cluster.
    ray.init(address=address)

    # Assume we have a uniform cluster that each node has the same number of
    # GPUs for now.
    valid_node_resources = []
    num_devices_per_node = None
    for node in ray.nodes():
        if not node['Alive']:
            continue
        resources = node['Resources']
        # A node without GPUs may not list a 'GPU' resource at all, so read
        # it with a default instead of indexing (which would raise KeyError).
        num_gpus = resources.get('GPU', 0)
        if num_gpus <= 0:
            continue
        if num_devices_per_node is None:
            num_devices_per_node = num_gpus
        else:
            assert num_devices_per_node == num_gpus, (
                "The number of GPUs per node is not uniform.")
        valid_node_resources.extend(
            key for key in resources if key.startswith('node:'))

    # Without this check the capacity test below would multiply by None and
    # fail with an unhelpful TypeError.
    assert num_devices_per_node is not None, (
        "No alive node with GPUs was found in the Ray cluster.")

    num_nodes = len(valid_node_resources)

    assert (pipeline_parallel_size * tensor_parallel_size
            <= num_nodes * num_devices_per_node), (
                "The number of required GPUs exceeds the total number of "
                "available GPUs.")
    if tensor_parallel_size >= num_devices_per_node:
        assert tensor_parallel_size % num_devices_per_node == 0, (
            "The number of tensor parallelism is not divisible by the "
            "number of GPUs per node.")
    else:
        assert num_devices_per_node % tensor_parallel_size == 0, (
            "The number of GPUs per node is not divisible by the number "
            "of tensor parallelism.")

    # Assign GPUs to pipeline stages.
    rank = 0
    current_node_id = 0
    current_device_id = 0
    distributed_init_method = None
    all_stage_devices = []

    for _ in range(pipeline_parallel_size):
        stage_devices = []
        for _ in range(tensor_parallel_size):
            node_resource = valid_node_resources[current_node_id]
            stage_devices.append((rank, node_resource, current_device_id))
            if distributed_init_method is None:
                # The node that hosts rank 0 provides the rendezvous address.
                ip = node_resource.split("node:")[-1]
                port = random.randint(10000, 20000)
                distributed_init_method = f"tcp://{ip}:{port}"
            rank += 1
            current_device_id += 1
            if current_device_id >= num_devices_per_node:
                # This node is full; continue on the next one.
                current_node_id += 1
                current_device_id = 0
        all_stage_devices.append(stage_devices)

    return (num_nodes, num_devices_per_node, distributed_init_method,
            all_stage_devices)


def add_server_arguments(
    parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    """Register the server command-line options on ``parser``.

    Args:
        parser: The parser to extend; it is modified in place.

    Returns:
        The same ``parser`` object, so the call can be chained.
    """
    # Model arguments
    parser.add_argument('--model', type=str, default='facebook/opt-125m',
                        help='model name')
    parser.add_argument('--model-path', type=str,
                        default='~/.cacheflow/model_weights',
                        help='model path to download and load the weights')
    # Parallel arguments
    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
                        default=1, help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
                        default=1, help='number of tensor parallel replicas')
    # KV cache arguments
    parser.add_argument('--block-size', type=int, default=16,
                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                        help='token block size')
    # NOTE(woosuk): If FlashAttention is used, the float data type is not
    # supported.
    parser.add_argument('--dtype', type=str, default='half',
                        choices=['half'], help='data type')
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=20,
                        help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--max-num-batched-tokens', type=int, default=2560,
                        help='maximum number of batched tokens per iteration')
    parser.add_argument('--max-num-sequences', type=int, default=256,
                        help='maximum number of sequences per iteration')
    parser.add_argument('--use-dummy-weights', action='store_true',
                        help='use dummy values for model weights')
    return parser