worker.py 15.7 KB
Newer Older
1
"""A GPU worker class."""
2
import gc
3
import os
4
from typing import Any, Dict, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
11
                         SpeculativeConfig, VisionLanguageConfig)
12
13
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
14
15
                              init_distributed_environment,
                              set_custom_all_reduce)
16
from vllm.lora.request import LoRARequest
17
from vllm.model_executor import set_random_seed
18
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
19
from vllm.worker.cache_engine import CacheEngine
20
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
21
from vllm.worker.model_runner import ModelRunner
22
from vllm.worker.worker_base import WorkerBase
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24

25
class Worker(WorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        """Store the configs and construct the model runner.

        The CUDA device itself is selected later, in ``init_device``.

        Args:
            model_config: Model configuration (dtype, seed, max length, ...).
            parallel_config: Tensor/pipeline parallel configuration.
            scheduler_config: Scheduler configuration, forwarded to the runner.
            device_config: Device configuration; only CUDA is supported.
            cache_config: KV cache configuration (block size, dtype, ...).
            load_config: Weight-loading configuration.
            local_rank: Rank of this worker on its node; selects ``cuda:N``.
            rank: Global rank of this worker.
            distributed_init_method: Init method for torch.distributed.
            lora_config: Optional LoRA configuration.
            vision_language_config: Optional vision-language configuration;
                currently incompatible with ``lora_config``.
            speculative_config: Accepted for interface compatibility; this
                class does not read it.
            is_driver_worker: Whether this worker is the driver. The driver
                must have ``rank == 0``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()
        self.vision_language_config = vision_language_config
        if self.vision_language_config:
            assert not self.lora_config, (
                "To be tested: vision language model with LoRA settings.")

        # Embedding models get a dedicated runner; all others use the default.
        ModelRunnerClass = (EmbeddingModelRunner if
                            self.model_config.embedding_mode else ModelRunner)
        self.model_runner = ModelRunnerClass(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CacheEngine
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[torch.Tensor]] = None

    def init_device(self) -> None:
        """Bind this worker to its CUDA device and set up distributed state.

        Also records the free GPU memory in ``self.init_gpu_memory``.
        ``determine_num_available_blocks`` uses that value as its profiling
        baseline.

        Raises:
            RuntimeError: If the configured device type is not CUDA.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model weights through the model runner."""
        self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save this worker's model shard via the model runner.

        Args:
            path: Destination passed to the runner.
            pattern: Optional file-name pattern passed to the runner.
            max_size: Optional maximum file size passed to the runner.
        """
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, each clamped to be >= 0.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        # Drop any LoRA adapters loaded during profiling before freeing memory.
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: If ``num_gpu_blocks`` cannot hold a single sequence
                of ``max_model_len`` tokens.
        """
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Create the cache engine and expose its GPU cache on the worker."""
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.gpu_cache = self.cache_engine.gpu_cache

    def _warm_up_model(self) -> None:
        """Capture CUDA graphs (unless eager mode is enforced) and reseed."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: torch.Tensor,
        blocks_to_swap_out: torch.Tensor,
        blocks_to_copy: torch.Tensor,
    ) -> None:
        """Issue the requested KV cache swap-in, swap-out and copy operations.

        As built by ``execute_model``, each argument is an int64 tensor viewed
        as ``(-1, 2)``. Empty tensors are skipped.
        """
        # Issue cache operations.
        if blocks_to_swap_in.numel() > 0:
            self.cache_engine.swap_in(blocks_to_swap_in)
        if blocks_to_swap_out.numel() > 0:
            self.cache_engine.swap_out(blocks_to_swap_out)
        if blocks_to_copy.numel() > 0:
            self.cache_engine.copy(blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        """Run one model step, coordinating non-driver workers via broadcast.

        Non-driver workers take part in one broadcast step and return ``[]``.

        On the driver:

        * A ``None`` request broadcasts an empty dict, which stops the other
          workers' execution loops, and returns ``[]``.
        * Otherwise the swap/copy metadata is broadcast and the cache
          operations are applied.
        * If there are no sequence groups, the driver returns ``[]``.
        * Otherwise it returns a one-element list with the runner's output.
        """
        if not self.is_driver_worker:
            self._execute_model_non_driver()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop.
            broadcast_tensor_dict({}, src=0)
            return []

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        num_seq_groups = len(seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)
        data: Dict[str, Any] = {
            "num_seq_groups": num_seq_groups,
            "blocks_to_swap_in": blocks_to_swap_in,
            "blocks_to_swap_out": blocks_to_swap_out,
            "blocks_to_copy": blocks_to_copy,
        }
        broadcast_tensor_dict(data, src=0)

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)

        # Worker only supports single-step execution. Wrap the output in a list
        # to conform to interface.
        return [output]

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.

        The loop ends only when the driver broadcasts an empty input, which
        happens when the driver's ``execute_model`` is called without a
        request. See `stop_remote_worker_execution_loop` for more details.
        """
        while self._execute_model_non_driver():
            pass

    def _execute_model_non_driver(self) -> bool:
        """Execute one model step in a parallel (non-driver) worker.

        Returns:
            False iff the driver broadcast an empty input, which is the stop
            signal for the execution loop. Otherwise True.

        A step with zero sequence groups (for example a swap-only step) still
        returns True. The driver sends such a step and then keeps
        broadcasting, so leaving the loop here would leave the driver's next
        broadcast without a receiver.
        """
        assert not self.is_driver_worker
        data = broadcast_tensor_dict(src=0)
        if not data:
            return False

        num_seq_groups = data.get("num_seq_groups", 0)
        blocks_to_swap_in = data.get("blocks_to_swap_in")
        blocks_to_swap_out = data.get("blocks_to_swap_out")
        blocks_to_copy = data.get("blocks_to_copy")
        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model. The driver
        # has not asked us to stop, though, so stay in the loop.
        if num_seq_groups == 0:
            return True

        self.model_runner.execute_model(None, self.gpu_cache)
        return True

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with ``lora_id`` from the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the ids of the LoRA adapters known to the model runner."""
        return self.model_runner.list_loras()

    @property
    def max_model_len(self) -> int:
        """Maximum sequence length from the model config."""
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model runner."""
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of the KV cache block size in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)


def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Set up torch.distributed and the model-parallel groups for one worker.

    Args:
        parallel_config: Supplies the world size, tensor/pipeline parallel
            sizes and whether custom all-reduce is disabled.
        rank: Global rank of this worker.
        distributed_init_method: Init method forwarded to
            ``init_distributed_environment``.
        local_rank: Rank of this worker on its node (``-1`` by default).
    """
    # Toggle custom all-reduce from the config before any groups are created.
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    init_distributed_environment(
        parallel_config.world_size,
        rank,
        distributed_init_method,
        local_rank,
    )

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
Woosuk Kwon's avatar
Woosuk Kwon committed
362
363
364
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380


def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                max_model_len) -> None:
    """Check that the GPU KV cache can hold one sequence of full length.

    Args:
        num_gpu_blocks: Number of KV cache blocks available on the GPU.
        block_size: Number of tokens per block.
        max_model_len: Maximum sequence length the model must support.

    Raises:
        ValueError: If there are no GPU blocks, or if the total token capacity
            (``block_size * num_gpu_blocks``) is below ``max_model_len``.
    """
    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    token_capacity = block_size * num_gpu_blocks
    if token_capacity < max_model_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({token_capacity}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")