worker.py 29.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4
import gc
5
import os
6
from contextlib import nullcontext
7
from typing import Dict, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8
9

import torch
10
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
import vllm.envs as envs
13
14
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
15
from vllm.device_allocator.cumem import CuMemAllocator
16
from vllm.distributed import (ensure_model_parallel_initialized,
17
18
                              init_distributed_environment,
                              set_custom_all_reduce)
19
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
20
from vllm.logger import init_logger
21
from vllm.lora.request import LoRARequest
22
from vllm.model_executor import set_random_seed
23
from vllm.model_executor.layers.sampler import SamplerOutput
24
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
25
from vllm.platforms import current_platform
26
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
27
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
28
29
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
                        memory_profiling)
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.worker.cache_engine import CacheEngine
31
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
32
33
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)
Woosuk Kwon's avatar
Woosuk Kwon committed
34

35
36
logger = init_logger(__name__)

37

38
class Worker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    # Draft-model types for which the target ModelRunner is constructed with
    # ``return_hidden_states=True`` (see ``__init__``), provided the draft
    # model type differs from the target model type.
    _HIDDEN_STATE_DRAFT_MODEL_TYPES = (
        "medusa",
        "mlp_speculator",
        "eagle",
        "deepseek_mtp",
        "glm4_moe_mtp",
        "mimo_mtp",
        "ernie_mtp",
        "qwen3_next_mtp",
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        """Create the worker and its model runner.

        Args:
            vllm_config: Full engine configuration, handed to ``WorkerBase``.
            local_rank: Rank on this node. Selects the CUDA device in
                ``init_device`` and gates profiler output in
                ``stop_profile``.
            rank: Global rank. Also written to ``parallel_config.rank``.
            distributed_init_method: Forwarded to the distributed environment
                setup performed in ``init_device``.
            is_driver_worker: Forwarded to the ``ModelRunner``.
            model_runner_cls: Optional wrapper type. When given, it is called
                with the constructed ``ModelRunner`` and the result replaces
                ``self.model_runner``.
        """
        WorkerBase.__init__(self, vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Return hidden states from the target model when the draft model is
        # one of the hidden-state-consuming speculators (e.g. mlp_speculator)
        # and is not the same model type as the target.
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args = {}
        if speculative_config is not None:
            draft_hf_config = speculative_config.draft_model_config.hf_config
            draft_model_type = draft_hf_config.model_type
            if (draft_model_type != model_config.hf_config.model_type
                    and draft_model_type
                    in self._HIDDEN_STATE_DRAFT_MODEL_TYPES):
                speculative_args = {"return_hidden_states": True}

        self.model_runner: GPUModelRunnerBase = ModelRunner(
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)

        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        # request_id -> full metadata; updated from deltas in
        # _get_cached_seq_group_metadata.
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

        # Buffers saved before a level-2 sleep, restored in wake_up.
        self._sleep_saved_buffers: Dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler.

        Raises:
            RuntimeError: If VLLM_TORCH_PROFILER_DIR was not set at init.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler and print a summary on local rank 0.

        Raises:
            RuntimeError: If VLLM_TORCH_PROFILER_DIR was not set at init.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        # only print profiler results on rank 0
        if self.local_rank == 0:
            print(self.profiler.key_averages().table(
                sort_by="self_cuda_time_total"))

    def sleep(self, level: int = 1) -> None:
        """Release GPU memory managed by the CuMemAllocator.

        Args:
            level: 1 passes ``offload_tags=("weights",)`` to the allocator.
                Any other value passes no offload tags. For level 2, the
                model buffers are first copied to CPU so ``wake_up`` can
                restore them.
        """
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else ())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Re-materialize allocator memory and restore saved buffers.

        Args:
            tags: Forwarded to ``CuMemAllocator.wake_up``. Presumably ``None``
                means all tags; confirm against the allocator.
        """
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags=tags)

        # Restore the buffers after level 2 sleep
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def init_device(self) -> None:
        """Bind this worker to its CUDA device and set up distributed state.

        Also checks dtype support, records a baseline memory snapshot for
        later profiling, and seeds the RNGs.

        Raises:
            RuntimeError: If the configured device type is not "cuda".
            ValueError: If the model dtype is unsupported by the GPU.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            # Baseline used by memory profiling and KV-cache sizing.
            self.baseline_snapshot = MemorySnapshot()
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load model weights via the model runner.

        When sleep mode is enabled, the weights are allocated inside the
        allocator's "weights" memory pool.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            context = nullcontext()
        with context:
            self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate sharded checkpoint saving to the model runner."""
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Delegate tensorizer serialization to the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    @torch.inference_mode()
    def determine_available_kv_cache_memory(self,
                                            total_gpu_memory: int) -> float:
        """Return the number of bytes that may be used for the KV cache.

        If ``cache_config.kv_cache_memory_bytes`` is set, a profile run is
        still executed (for compilation), but that value is returned as-is.
        Otherwise a profiled forward pass measures non-KV-cache memory. The
        result is ``total_gpu_memory * gpu_memory_utilization`` minus that
        amount. The profiling path also stores ``non_torch_memory``,
        ``peak_activation_memory``, ``requested_memory`` and
        ``available_kv_cache_memory`` on ``self``; ``_warm_up_model`` reads
        them later.

        Args:
            total_gpu_memory: Total device memory in bytes.
        """
        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            GiB = lambda b: b / GiB_bytes
            msg = (
                f"Initial free memory "
                f"{GiB(self.baseline_snapshot.free_memory):.2f} "
                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly.")
            logger.info(msg)
            return kv_cache_memory_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.baseline_snapshot,
                weights_memory=self.model_runner.model_memory_usage) as result:
            self.model_runner.profile_run()

        self.non_torch_memory = result.non_torch_increase
        self.peak_activation_memory = result.torch_peak_increase

        self._assert_memory_footprint_increased_during_profiling()

        self.requested_memory = total_gpu_memory * \
            self.cache_config.gpu_memory_utilization

        self.available_kv_cache_memory = (self.requested_memory -
                                          result.non_kv_cache_memory)

        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
               "the current vLLM instance can use "
               "total_gpu_memory "
               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
               " x gpu_memory_utilization "
               f"({self.cache_config.gpu_memory_utilization:.2f})"
               f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n"
               "model weights take "
               f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
               " non_torch_memory takes "
               f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
               " PyTorch activation peak memory takes "
               f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
               " the rest of the memory reserved for KV Cache is "
               f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.")

        logger.info(msg)
        return self.available_kv_cache_memory

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, both non-negative. Both are
            0 when the cache block size is 0.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        _, total_gpu_memory = torch.cuda.mem_get_info()
        available_kv_cache_memory = self.determine_available_kv_cache_memory(
            total_gpu_memory)

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        cache_block_size = self.get_cache_block_size_bytes()
        if cache_block_size == 0:
            num_gpu_blocks = 0
            num_cpu_blocks = 0
        else:
            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                 cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Final cleanup
        gc.collect()

        return num_gpu_blocks, num_cpu_blocks

    def _assert_memory_footprint_increased_during_profiling(self):
        """Assert device memory in use grew relative to the init baseline."""
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        free_gpu_memory, total = torch.cuda.mem_get_info()
        cuda_memory = total - free_gpu_memory
        assert self.baseline_snapshot.cuda_memory < cuda_memory, (
            "Error in memory profiling. "
            f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
            f"currently used memory {cuda_memory}. "
            f"This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: If ``num_gpu_blocks`` is invalid for the model (see
                ``raise_if_cache_size_invalid``).
        """
        raise_if_cache_size_invalid(
            num_gpu_blocks, self.cache_config.block_size,
            self.cache_config.is_attention_free,
            self.model_config.max_model_len,
            self.parallel_config.pipeline_parallel_size)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            context = nullcontext()
        with context:
            self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Create one CacheEngine per pipeline stage and bind the KV cache.

        Layers that declare a KV-sharing target are bound to the target
        layer's cache rather than getting their own.
        """
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.gpu_cache = [engine.gpu_cache for engine in self.cache_engine]

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        shared_kv_cache_layers: dict[str, str] = {}

        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

        for layer_name, attn_module in attn_layers.items():
            if (kv_tgt_layer :=
                    attn_module.kv_sharing_target_layer_name) is not None:
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as this layer does
                # not exist, and doesn't allocate KV cache for the layer. This
                # enables the memory saving of cross-layer kv sharing, allowing
                # a given amount of memory to accommodate longer context lengths
                # or enable more requests to be processed simultaneously.
                shared_kv_cache_layers[layer_name] = kv_tgt_layer

        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.gpu_cache, shared_kv_cache_layers)

    def _warm_up_model(self) -> None:
        """Compile/warm up extra sizes, capture CUDA graphs, and reseed.

        When the KV-cache size came from profiling, this also logs suggested
        ``--kv-cache-memory`` values that account for CUDA graph memory.
        """
        compilation_config = self.vllm_config.compilation_config
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes
                if x not in compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model(
                self.gpu_cache)

        if (self.cache_config.kv_cache_memory_bytes is None
                and hasattr(self, "peak_activation_memory")):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            GiB = lambda b: round(b / GiB_bytes, 2)
            non_kv_cache_memory = (self.model_runner.model_memory_usage +
                                   self.peak_activation_memory +
                                   self.non_torch_memory +
                                   cuda_graph_memory_bytes)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            kv_cache_memory_bytes_to_gpu_limit = (
                self.baseline_snapshot.free_memory - non_kv_cache_memory -
                redundancy_buffer_memory)
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory) - non_kv_cache_memory -
                redundancy_buffer_memory)

            msg = (
                f"Free memory on device "
                f"({GiB(self.baseline_snapshot.free_memory)}/"
                f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). "
                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
                f"requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{int(self.available_kv_cache_memory)} bytes.")
            logger.info(msg)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @property
    def do_metadata_broadcast(self) -> bool:
        """True when tensor parallelism spans more than one rank."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """Per-pipeline-stage GPU KV cache, or None before initialization."""
        return self.gpu_cache

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Convert the request's block operations into tensors.

        Swap-in/out pairs become CPU int64 tensors of shape (-1, 2); copy
        pairs become an int64 tensor of the same shape on this worker's
        device.
        """
        virtual_engine = execute_model_req.virtual_engine
        num_steps = execute_model_req.num_steps
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
            num_steps=num_steps,
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Run the non-empty swap-in, swap-out and copy cache operations."""
        virtual_engine = worker_input.virtual_engine
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            self.cache_engine[virtual_engine].swap_in(
                worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            self.cache_engine[virtual_engine].swap_out(
                worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)

    def _get_cached_seq_group_metadata(
            self,
            seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                SequenceGroupMetadataDelta]],
            finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
        """Return a list of cached Sequence Group Metadata after updating its
        state.

        It is used because scheduler only sends delta to workers to reduce
        the data payload size. The function also cleans up cache based on
        a given `finished_request_ids`.
        """
        new_seq_group_metadata_list = []
        for metadata_or_delta in seq_group_metadata_list:
            request_id = metadata_or_delta.request_id
            if request_id not in self._seq_group_metadata_cache:
                # The first prefill.
                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                self._seq_group_metadata_cache[request_id] = metadata_or_delta
            elif isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
                # The first prefill is already cached.
                self._seq_group_metadata_cache[request_id].apply_delta(
                    metadata_or_delta)
            else:
                # If metadata snapshot is sent again, it is
                # preempted. Reset the cache because we need to start
                # from scratch.
                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                self._seq_group_metadata_cache[request_id] = metadata_or_delta

            new_seq_group_metadata_list.append(
                self._seq_group_metadata_cache[request_id])

        # Clean up finished ids. An id may never have been cached here
        # (presumably e.g. a request finished/aborted before it was ever sent
        # to this worker; confirm against the scheduler), so tolerate missing
        # entries instead of raising KeyError.
        for finished_id in finished_request_ids:
            self._seq_group_metadata_cache.pop(finished_id, None)

        return new_seq_group_metadata_list

    def _execute_model_spmd(
        self,
        execute_model_req: Optional[ExecuteModelRequest],
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Expand metadata deltas into full metadata, then run the base SPMD
        execution path. ``execute_model_req`` may be None, in which case it
        is passed through unchanged.
        """
        if execute_model_req is not None:
            new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.finished_requests_ids)

            execute_model_req.seq_group_metadata_list = (
                new_seq_group_metadata_list)
        output = super()._execute_model_spmd(execute_model_req,
                                             intermediate_tensors)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Delegate to the model runner."""
        return self.model_runner.list_loras()

    @property
    def max_model_len(self) -> int:
        """Maximum model length from the model config."""
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model runner."""
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of the KV cache block size in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)

Woosuk Kwon's avatar
Woosuk Kwon committed
607

608
def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Steps, in order:
      1. toggle custom all-reduce (enabled unless
         ``parallel_config.disable_custom_all_reduce``);
      2. initialize the global process group with the platform's
         ``dist_backend``;
      3. set up tensor / pipeline / decode-context parallel groups;
      4. set up KV transfer.

    Args:
        vllm_config: Engine config; only ``parallel_config`` is read here
            directly, and the full config is passed to KV-transfer setup.
        rank: Global rank of this process.
        distributed_init_method: Rendezvous method forwarded to
            ``init_distributed_environment``.
        local_rank: Rank on this node, forwarded as-is.
    """
    pcfg = vllm_config.parallel_config
    set_custom_all_reduce(not pcfg.disable_custom_all_reduce)

    world_size = pcfg.world_size
    backend = current_platform.dist_backend
    init_distributed_environment(world_size, rank, distributed_init_method,
                                 local_rank, backend)

    tp_size = pcfg.tensor_parallel_size
    pp_size = pcfg.pipeline_parallel_size
    dcp_size = pcfg.decode_context_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size, dcp_size)

    ensure_kv_transfer_initialized(vllm_config)

628

629
630
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
631
632
633
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
634
            gpu_name = current_platform.get_device_name()
635
636
637
638
639
640
641

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

642
643
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
644
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
645
                "You can use float16 instead by explicitly setting the "
Woosuk Kwon's avatar
Woosuk Kwon committed
646
                "`dtype` flag in CLI, for example: --dtype=half.")
647
648


649
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
                                max_model_len, pipeline_parallel_size) -> None:
    """Validate the number of GPU KV-cache blocks for the model.

    Rules:
      * attention-free models must be given exactly 0 blocks;
      * other models need at least 1 block, and each pipeline stage's share
        (``num_gpu_blocks // pipeline_parallel_size`` blocks of
        ``block_size`` tokens) must hold ``max_model_len`` tokens.

    Raises:
        ValueError: If any rule above is violated.
    """
    if is_attention_free:
        if num_gpu_blocks != 0:
            raise ValueError(
                "No memory should be allocated for the cache blocks "
                f"for an attention-free model, but {num_gpu_blocks} "
                "blocks are allocated.")
    elif num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    blocks_per_stage = num_gpu_blocks // pipeline_parallel_size
    tokens_per_stage = block_size * blocks_per_stage
    if is_attention_free or max_model_len <= tokens_per_stage:
        return
    raise ValueError(
        f"The model's max seq len ({max_model_len}) "
        "is larger than the maximum number of tokens that can be "
        f"stored in KV cache ({tokens_per_stage}). Try increasing "
        "`gpu_memory_utilization` or decreasing `max_model_len` when "
        "initializing the engine.")