"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Tuple, Set, Optional

import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.model_runner = ModelRunner(model_config,
                                        parallel_config,
                                        scheduler_config,
                                        device_config,
                                        lora_config=self.lora_config,
                                        kv_cache_dtype=kv_cache_dtype,
                                        is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self, cupy_port: Optional[int] = None) -> None:
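        """Set up the device and the distributed environment for this worker."""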
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
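            # Record the free GPU memory before the model is loaded;
            # profile_num_available_blocks() later subtracts the post-profiling
            # free memory from this value to obtain the peak usage.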
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_distributed_environment(self.parallel_config, self.rank,
                                     cupy_port, self.distributed_init_method)
        if not self.parallel_config.disable_custom_all_reduce:
            init_custom_ar()
        # Initialize the model.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
        cache_dtype: str,
    ) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model and returns the maximum
        number of GPU and CPU cache blocks that can be allocated.

        Args:
            block_size: The size of the cache block.
            gpu_memory_utilization: The fraction of the total GPU memory to use.
            cpu_swap_space: The size of the CPU swap space in bytes.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory

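        # A cache block stores the key and value vectors for `block_size`
        # tokens in every layer; per block this comes to roughly
        # 2 * block_size * num_layers * num_kv_heads * head_size * dtype_size
        # bytes (see CacheEngine.get_cache_block_size for the exact formula).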
        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, cache_dtype, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
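        # Illustrative arithmetic (assumed values, not measurements): a 40 GiB
        # GPU with gpu_memory_utilization=0.9, a profiled peak of 30 GiB, and
        # a 2 MiB cache block gives int((40 * 0.9 - 30) GiB / 2 MiB) = 3072
        # GPU blocks; a 4 GiB swap space gives 2048 CPU blocks.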
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
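        """Allocate the GPU and CPU KV caches once profiling has determined
        the cache configuration (block counts, block size, and dtype)."""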
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def warm_up_model(self) -> None:
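        """Capture CUDA graphs for the model (skipped when `enforce_eager` is
        set) so that decode steps replay with lower CPU launch overhead."""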
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
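        """Issue the scheduler's pending cache operations and wait for them.

        The swap maps go from source to destination block number (CPU to GPU
        for swap-in, GPU to CPU for swap-out); the copy map goes from a source
        block to a list of destination blocks, e.g. for copy-on-write after a
        sequence fork.
        """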
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # Wait for cache operations to finish.
        # TODO(woosuk): Profile swapping overhead and optimize if needed.
        if cache_events is not None:
            for event in cache_events:
                event.wait()

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
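        """Run one execution step across all workers.

        Only the driver worker (rank 0) receives the scheduler output
        directly; it broadcasts the sequence-group count and the swap/copy
        metadata so that every rank issues the same cache operations before
        running the model.
        """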
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            data = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return {}

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()


def init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    cupy_port: Optional[int],
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    if cupy_utils.is_initialized():
        cupy_world_size = cupy_utils.get_world_size()
        if cupy_world_size != parallel_config.world_size:
            raise RuntimeError(
                "cupy.distributed is already initialized but the cupy world "
                "size does not match parallel_config.world_size "
                f"({cupy_world_size} vs. {parallel_config.world_size}).")
    elif parallel_config.world_size > 1 and cupy_port is not None:
        # NOTE(woosuk): We don't initialize CuPy process group when world size
        # is 1.
        # TODO(woosuk): Support multi-node connection.
        cupy_utils.init_process_group(
            world_size=parallel_config.world_size,
            rank=rank,
            host="localhost",
            port=cupy_port,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    if cupy_utils.is_initialized():
        cupy_utils.all_reduce(torch.zeros(1).cuda())
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")