gpu_worker.py 39.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from collections.abc import Callable
8
from contextlib import AbstractContextManager, nullcontext
9
from datetime import timedelta
10
from types import NoneType
11
from typing import TYPE_CHECKING, Any
12

13
import numpy as np
14
import torch
15
import torch.nn as nn
16

17
import vllm.envs as envs
18
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
19
from vllm.config.compilation import CompilationMode
20
21
22
23
24
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
25
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
26
from vllm.distributed.eplb.eplb_utils import override_envs_for_eplb
27
28
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
29
    ensure_kv_transfer_shutdown,
30
31
32
    get_kv_transfer_group,
    has_kv_transfer_group,
)
33
from vllm.distributed.parallel_state import (
34
    Handle,
35
36
37
    get_pp_group,
    get_tp_group,
)
38
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
39
from vllm.logger import init_logger
40
from vllm.lora.request import LoRARequest
41
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
42
from vllm.platforms import current_platform
43
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
44
from vllm.sequence import IntermediateTensors
45
from vllm.tasks import SupportedTask
46
from vllm.tracing import instrument
47
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
48
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
49
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
50
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
51
52
53
54
55
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
56
from vllm.v1.utils import compute_iteration_details, report_usage_stats
57
from vllm.v1.worker.utils import is_residual_scattered_for_sp
58
from vllm.v1.worker.worker_base import WorkerBase
59
from vllm.v1.worker.workspace import init_workspace_manager
60

61
from ...model_executor.model_loader import TensorizerLoader
62
from .gpu.warmup import warmup_kernels
63
64
from .utils import request_memory

65
66
67
logger = init_logger(__name__)

if TYPE_CHECKING:
68
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
69
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
70
71


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
class AsyncIntermediateTensors(IntermediateTensors):
    """IntermediateTensors whose payload is synchronized lazily.

    The tensors may still be in flight when this object is built: the caller
    hands over the handles of the pending communication together with
    optional zero-argument callbacks to run once it has finished. Nothing
    blocks until ``.tensors`` is first read (or ``wait_for_comm`` is called
    explicitly); at that point every handle is waited on and every callback
    runs, after which further reads are free.
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        comm_handles: list[Handle] | None = None,
        comm_postprocess: list[Callable[[], None]] | None = None,
    ) -> None:
        super().__init__(tensors)
        # Pending communication; `.wait()` is called on each handle.
        self._comm_handles = comm_handles
        # Callbacks run after every handle has completed.
        self._comm_postprocess = comm_postprocess
        # Becomes True once synchronization has finished; read by
        # `__getattribute__` to decide whether `.tensors` must block.
        self._comm_waited = False

    def wait_for_comm(self) -> None:
        """Wait on all pending handles, then run the post-processing callbacks.

        After the first call that completes without raising, this is a no-op.
        """
        if not self._comm_waited:
            for pending in self._comm_handles or ():
                pending.wait()
            for callback in self._comm_postprocess or ():
                callback()
            self._comm_waited = True

    def __getattribute__(self, name: str):
        # Intercept reads of `.tensors` so callers never observe data that is
        # still being received. `object.__getattribute__` bypasses this
        # override for the internal lookups.
        lookup = object.__getattribute__
        if name == "tensors" and not lookup(self, "_comm_waited"):
            lookup(self, "wait_for_comm")()
        return lookup(self, name)


104
class Worker(WorkerBase):
105
106
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create a GPU worker.

        Device selection, distributed initialization and model-runner
        construction happen later, in :meth:`init_device`.

        Raises:
            ValueError: if ``profiler_config.profiler`` is not ``"torch"``,
                ``"cuda"`` or ``None``.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Apply the float32 matmul precision selected through the vLLM env var.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor

        self.elastic_ep_executor = ElasticEPScalingExecutor(self)

        # Model buffers copied to CPU by a weight-discarding sleep(), restored
        # (and cleared) by wake_up().
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Weight transfer engine: built right away when a
        # weight_transfer_config is present, otherwise None.
        transfer_config = self.vllm_config.weight_transfer_config
        self.weight_transfer_engine = (
            None
            if transfer_config is None
            else WeightTransferEngineFactory.create_engine(
                transfer_config,
                self.vllm_config.parallel_config,
            )
        )

        # Torch/CUDA profiler, configured through profiler_config. The wrapper
        # itself is created lazily when profiling starts (so trace naming has
        # all the information it needs); here the profiler type is only
        # validated.
        self.profiler: Any | None = None
        self.profiler_config = vllm_config.profiler_config
        profiler_kind = self.profiler_config.profiler
        if profiler_kind not in ("torch", "cuda", None):
            raise ValueError(f"Unknown profiler type: {profiler_kind}")

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

        # Handles of non-blocking PP sends from the previous iteration;
        # execute_model() waits on them before starting new work.
        self._pp_send_work: list[Handle] = []

    def sleep(self, level: int = 1) -> None:
        """Release GPU memory through the CuMem allocator.

        Args:
            level: ``1`` asks the allocator to offload the ``"weights"`` pool.
                Any other value (level 2 in practice) passes no offload tags,
                so the weights are not backed up (presumably discarded --
                confirm against ``CuMemAllocator.sleep``). In that case the
                model buffers are first copied to CPU so :meth:`wake_up` can
                restore them.

        Raises:
            AssertionError: if device free memory went down across the sleep.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Only level 1 asks the allocator to offload (back up) the weights.
        offload_tags: tuple[str, ...] = ("weights",) if level == 1 else ()

        # Any level that does not offload the weights must snapshot the model
        # buffers itself. Previously only `level == 2` did this, although every
        # level != 1 passes empty offload tags, so e.g. level=3 lost the
        # buffer contents with nothing to restore them from.
        if not offload_tags:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone() for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=offload_tags)

        # Report how much device memory the sleep released.
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(freed_bytes),
            format_gib(used_bytes),
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Re-map GPU memory released by :meth:`sleep`.

        Args:
            tags: Allocator pools to wake (e.g. ``"weights"``, ``"kv_cache"``).
                ``None`` is treated as waking every pool, as the checks below
                assume.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore buffers snapshotted by a weight-discarding sleep only once
        # the "weights" pool is mapped again. load_model() allocates the model
        # inside that pool, so the buffers presumably live there as well.
        # Copying before then would write into still-released memory, and
        # clearing the snapshot would lose it for the later weights wake-up.
        weights_awake = tags is None or "weights" in tags
        if self._sleep_saved_buffers and weights_awake:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up, the internal state of
        # cache_engine must be reset, especially the FP8 scaling factor.
        if (
            (tags is None or "kv_cache" in tags)
            and self.cache_config.cache_dtype.startswith("fp8")
            and hasattr(self.model_runner, "init_fp8_kv_scales")
        ):
            self.model_runner.init_fp8_kv_scales()

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context manager that places allocations in CuMem pool `tag`.

        With sleep mode disabled this is a no-op ``nullcontext()``. For the
        ``"weights"`` tag it additionally asserts that the allocator has no
        current usage, because sleep mode supports one instance per process.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU and CPU KV-cache block counts on ``cache_config``."""
        cache_config = self.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

    @instrument(span_name="Init device")
    def init_device(self):
        """Bind this worker to its CUDA device and build per-device state.

        In order, this:

        * shifts ``local_rank`` by the local DP rank (single-node, non-Ray
          setups only);
        * selects the device;
        * initializes the distributed environment;
        * seeds the RNGs;
        * takes the initial memory snapshot and the requested-memory budget;
        * sets up the workspace manager;
        * constructs the V1 or V2 model runner.

        Raises:
            RuntimeError: if ``device_config.device_type`` is not ``"cuda"``.
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        parallel_config = self.parallel_config
        offset_rank_by_dp = (
            parallel_config.distributed_executor_backend
            not in ("ray", "external_launcher")
            and parallel_config.data_parallel_backend != "ray"
            and parallel_config.nnodes_within_dp == 1
        )
        if offset_rank_by_dp:
            # Use local DP rank if available, otherwise use global DP rank.
            dp_local_rank = parallel_config.data_parallel_rank_local
            if dp_local_rank is None:
                dp_local_rank = parallel_config.data_parallel_index

            # Device index = DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            tp_pp_world_size = (
                parallel_config.pipeline_parallel_size
                * parallel_config.tensor_parallel_size
            )
            self.local_rank += dp_local_rank * tp_pp_world_size
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            visible_device_count = (
                torch.cuda.device_count() if torch.cuda.is_available() else 0
            )
            assert parallel_config.local_world_size <= visible_device_count, (
                f"local_world_size ({parallel_config.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({visible_device_count})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # snapshot, so NCCL buffers are already allocated when available
        # memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        if self.use_v2_model_runner:
            logger.info_once("Using V2 Model Runner", scope="local")

        set_random_seed(self.model_config.seed)

        # Drop garbage and cached blocks, then snapshot memory post-NCCL.
        gc.collect()
        torch.accelerator.empty_cache()
        snapshot = MemorySnapshot(device=self.device)
        self.init_snapshot = snapshot
        self.requested_memory = request_memory(snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Two micro-batches when DBO is enabled, otherwise one.
        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, num_ubatches)

        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model, allocating it inside the "weights" memory pool.

        When this process was launched as an elastic-EP scale-up worker
        (``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH=1``), dummy weights are loaded.
        The expert mapping received from the elastic EP executor sizes the
        redundant experts beforehand and configures EPLB afterwards.
        """
        dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        if dummy_weights:
            mapping, num_logical_experts, old_num_physical_experts = (
                self.elastic_ep_executor.receive_expert_mapping()
            )
            # Dim 1 of the expanded mapping is the physical expert count; the
            # excess over the logical experts is the redundant-expert count.
            self.parallel_config.eplb_config.num_redundant_experts = (
                mapping.shape[1] - num_logical_experts
            )

        with (
            self._maybe_get_memory_pool_context(tag="weights"),
            set_current_vllm_config(self.vllm_config),
        ):
            self.model_runner.load_model(load_dummy_weights=dummy_weights)

        if dummy_weights:
            self.model_runner.setup_eplb_from_mapping(
                mapping, old_num_physical_experts
            )
            self.model_runner.eep_eplb_suppressed = True

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

    def reload_weights(self, *args, **kwargs) -> None:
        """Reload weights; all arguments pass through to the model runner."""
        reload = self.model_runner.reload_weights
        reload(*args, **kwargs)

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return how many bytes the KV cache may use on this device.

        If ``cache_config.kv_cache_memory_bytes`` is set, that value is
        returned unchanged after one ``profile_run()``. The run is still needed
        to compile the model for ``max_num_batched_tokens``. In this mode
        ``gpu_memory_utilization`` is not respected.

        Otherwise a forward pass with dummy inputs is profiled under
        ``memory_profiling``. The KV-cache budget is the requested memory
        (computed by ``request_memory`` in ``init_device``) minus everything
        the profile attributes to non-KV-cache use. The profiled non-torch
        and peak-activation figures are stored on ``self``, where the
        warm-up stage reads them.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        configured_bytes = self.cache_config.kv_cache_memory_bytes
        if configured_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()
            notice = (
                f"Initial free memory {format_gib(self.init_snapshot.free_memory)} "
                f"GiB, reserved {format_gib(configured_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(notice)
            return configured_bytes

        # Profile a forward pass with dummy inputs to measure model memory use.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        init_free = self.init_snapshot.free_memory
        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert init_free >= free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(init_free)} GiB, "
            f"current free memory {format_gib(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Initially free memory that lies outside the requested budget.
        unrequested_memory = init_free - self.requested_memory
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(init_free),
            self.cache_config.gpu_memory_utilization,
            format_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_gpu_memory),
            format_gib(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return ``{tp_rank: metadata}`` from the KV connector, if any.

        Returns ``None`` when there is no KV transfer group, or when the
        connector has no handshake metadata to exchange across workers.
        """
        if has_kv_transfer_group():
            metadata = get_kv_transfer_group().get_handshake_metadata()
            if metadata is not None:
                return {get_tp_group().rank_in_group: metadata}
        return None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV-cache spec reported by the model runner."""
        spec = self.model_runner.get_kv_cache_spec()
        return spec

    def update_max_model_len(self, max_model_len: int) -> None:
        """Propagate an auto-fitted ``max_model_len`` to this worker.

        Used when ``max_model_len=-1`` lets the engine pick the longest context
        that fits in GPU memory. The model config, and the model runner if one
        exists, are updated to the engine's decision.
        """
        self.model_config.max_model_len = max_model_len
        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)
        logger.debug("Updated max_model_len to %d", max_model_len)

    @instrument(span_name="Allocate KV cache")
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate the GPU KV cache described by ``kv_cache_config``.

        Also records the final block count on ``cache_config`` for the warm-up
        stage, and initializes the KV transfer connector before allocation.
        """
        # Update local config with adjusted num blocks after profiling,
        # so that it's available to the warmup stage.
        self.cache_config.num_gpu_blocks = kv_cache_config.num_blocks

        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # With sleep mode enabled, allocate inside the "kv_cache" pool so that
        # sleep()/wake_up() can release and re-map it by tag. Otherwise this is
        # a nullcontext. This reuses the helper load_model() uses for
        # "weights" instead of duplicating the allocator selection here.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    @instrument(span_name="Warmup (GPU)")
    def compile_or_warm_up_model(self) -> float:
        """Compile and warm up the model before serving.

        Steps, in order:

        1. Run a dummy batch for each eager warm-up size (largest first), then
           call ``maybe_remove_all_loras``, since the dummy runs pass
           ``remove_lora=False``.
        2. Warm up and tune kernels via ``kernel_warmup``.
        3. Capture CUDA graphs unless ``enforce_eager`` is set.
        4. If the KV-cache size came from memory profiling, debug-log a
           suggested ``--kv-cache-memory`` value.
        5. Warm up sampling. The V2 runner runs ``execute_model`` +
           ``sample_tokens``; the V1 runner does a dummy sampler or pooler
           run on the last PP rank.
        6. Re-seed the RNGs so warm-up does not perturb the random state.

        Returns:
            ``self.compilation_config.compilation_time``.
        """
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(self._collect_compile_warmup_sizes(), reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = (
            0 if self.model_config.enforce_eager else self.model_runner.capture_model()
        )

        # determine_available_memory() sets peak_activation_memory only on its
        # profiling path, so this excludes the explicit kv_cache_memory_bytes case.
        kv_budget_was_profiled = (
            self.cache_config.kv_cache_memory_bytes is None
            and hasattr(self, "peak_activation_memory")
        )
        if kv_budget_was_profiled:
            self._log_kv_cache_memory_hint(cuda_graph_memory_bytes)

        if self.use_v2_model_runner:
            # V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
            warmup_kernels(self.model_runner, self.execute_model, self.sample_tokens)
        elif get_pp_group().is_last_rank:
            # V1: Warm up sampler and preallocate memory buffer for logits and other
            # sampling related tensors of max possible shape to avoid memory
            # fragmentation issue.
            # NOTE: This is called after `capture_model` on purpose to prevent
            # memory buffers from being cleared by `torch.accelerator.empty_cache`.
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

        return self.compilation_config.compilation_time

    def _collect_compile_warmup_sizes(self) -> list[Any]:
        """Return batch sizes to compile/warm up eagerly (VLLM_COMPILE only).

        Starts from ``compile_sizes``. When CUDA graphs are enabled, sizes
        that graph capture already covers are dropped. Then the end of every
        compile range containing none of the known sizes is appended, so that
        range is compiled/warmed up as well. Returns an empty list for other
        compilation modes.
        """
        compilation_config = self.vllm_config.compilation_config
        if compilation_config.mode != CompilationMode.VLLM_COMPILE:
            return []

        # Sizes users still want compiled although they are not cudagraph
        # capture sizes, e.g. the max-num-batched token size in chunked prefill.
        compile_sizes = compilation_config.compile_sizes
        sizes = compile_sizes.copy() if compile_sizes is not None else []

        capture_sizes: list[int] = []
        if compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
            capture_sizes = compilation_config.cudagraph_capture_sizes or []
            sizes = [size for size in sizes if size not in capture_sizes]

        covered = set(capture_sizes)
        covered.update(size for size in sizes if isinstance(size, int))
        for compile_range in compilation_config.get_compile_ranges():
            if not any(size in compile_range for size in covered):
                sizes.append(compile_range.end)
        return sizes

    def _log_kv_cache_memory_hint(self, cuda_graph_memory_bytes: int) -> None:
        """Debug-log suggested ``--kv-cache-memory`` values.

        ``memory_profiling`` does not account for CUDA graph memory, so the
        profiled KV-cache budget may not use all GPU memory. This log lets
        users pin the KV-cache size explicitly instead.
        """
        # empirically observed that the memory profiling may
        # slightly underestimate the memory consumption.
        # So leave a small buffer (=150MiB) to avoid OOM.
        redundancy_buffer_memory = 150 * (1 << 20)
        non_kv_cache_memory = (
            self.model_runner.model_memory_usage
            + self.peak_activation_memory
            + self.non_torch_memory
            + cuda_graph_memory_bytes
        )
        to_gpu_limit = (
            self.init_snapshot.free_memory
            - non_kv_cache_memory
            - redundancy_buffer_memory
        )
        to_requested_limit = (
            int(self.requested_memory)
            - non_kv_cache_memory
            - redundancy_buffer_memory
        )
        hint = (
            f"Free memory on device "
            f"({format_gib(self.init_snapshot.free_memory)}/"
            f"{format_gib(self.init_snapshot.total_memory)} GiB) on startup. "
            f"Desired GPU memory utilization is "
            f"({self.cache_config.gpu_memory_utilization}, "
            f"{format_gib(self.requested_memory)} GiB). "
            f"Actual usage is {format_gib(self.model_runner.model_memory_usage)} "
            f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
            f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
            f"for non-torch memory, and {format_gib(cuda_graph_memory_bytes)} "
            f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
            f"config with `--kv-cache-memory="
            f"{to_requested_limit}` "
            f"({format_gib(to_requested_limit)} GiB) to fit "
            f"into requested memory, or `--kv-cache-memory="
            f"{to_gpu_limit}` "
            f"({format_gib(to_gpu_limit)} GiB) to fully "
            f"utilize gpu memory. Current kv cache memory in use is "
            f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
        )
        logger.debug(hint)

    def reset_mm_cache(self) -> None:
        """Delegate to ``model_runner.reset_mm_cache()``."""
        runner = self.model_runner
        runner.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        """Delegate to ``model_runner.reset_encoder_cache()``."""
        runner = self.model_runner
        runner.reset_encoder_cache()

    def get_model(self) -> nn.Module:
        """Return the module reported by the model runner."""
        model = self.model_runner.get_model()
        return model

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        tasks = self.model_runner.get_supported_tasks()
        return tasks

    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Return the encoder timing statistics collected by the model runner."""
        stats = self.model_runner.get_encoder_timing_stats()
        return stats

    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels this step in the profiler trace.

        Without an active profiler this is ``nullcontext()``. Otherwise the
        profiler is stepped, and the label records the context and generation
        request counts together with their token counts, e.g.
        ``execute_context_2(512)_generation_8(8)``. A context request is one
        that has not yet generated any tokens.
        """
        if not self.profiler:
            return nullcontext()

        self.profiler.step()

        details = compute_iteration_details(scheduler_output)
        counts = (
            details.num_ctx_requests,
            details.num_ctx_tokens,
            details.num_generation_requests,
            details.num_generation_tokens,
        )
        annotation = "execute_context_{}({})_generation_{}({})".format(
            *map(str, counts)
        )
        return self.profiler.annotate_context_manager(annotation)

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Delegate token sampling to the model runner under inference mode.

        Args:
            grammar_output: Grammar output for this step, or ``None``.
        """
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)

    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one model step for ``scheduler_output`` on this worker.

        With pipeline parallelism (PP), a rank that is not the first stage
        first posts a receive for intermediate tensors from the previous
        stage. A rank whose model runner returns ``IntermediateTensors``
        launches a non-blocking send of them to the next stage. That send is
        awaited at the start of the next call.

        Args:
            scheduler_output: The batch scheduled for this step.

        Returns:
            The model runner's output when it is a ``ModelRunnerOutput``,
            ``AsyncModelRunnerOutput`` or ``None``. Returns ``None`` after
            handing intermediate tensors to the next PP stage.
        """
        # Ensure any previous non-blocking PP sends are complete.
        if self._pp_send_work:
            for handle in self._pp_send_work:
                handle.wait()
            self._pp_send_work = []

        intermediate_tensors = None
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0

        # Per-tensor flags forwarded as ``all_gather_tensors`` to the PP
        # send/recv, with the TP group as ``all_gather_group``. The dict stays
        # empty unless PP, sequence parallelism (``enable_sp``) and a
        # non-empty batch are all present.
        all_gather_tensors = {}
        compilation_config = self.vllm_config.compilation_config
        parallel_config = self.vllm_config.parallel_config

        if (
            parallel_config.pipeline_parallel_size > 1
            and compilation_config.pass_config.enable_sp
            and forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            num_scheduled_tokens_np = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(num_scheduled_tokens_np),
                    num_scheduled_tokens_np=num_scheduled_tokens_np,
                    max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            # The residual flag depends on the padded token count
            # (batch_desc.num_tokens), not on the raw scheduled count.
            all_gather_tensors = {
                "residual": not is_residual_scattered_for_sp(
                    self.vllm_config, batch_desc.num_tokens
                )
            }

        if forward_pass and not get_pp_group().is_first_rank:
            tensor_dict, comm_handles, comm_postprocess = (
                get_pp_group().irecv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )
            assert tensor_dict is not None
            # Hand the posted receive to the model runner together with its
            # comm handles and postprocess hooks.
            intermediate_tensors = AsyncIntermediateTensors(
                tensor_dict,
                comm_handles=comm_handles,
                comm_postprocess=comm_postprocess,
            )

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if (
                self.use_v2_model_runner
                and self.model_runner.is_pooling_model
                and output is None
            ):
                output = self.model_runner.pool()  # type: ignore
            if isinstance(
                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
            ):
                return output

        assert isinstance(output, IntermediateTensors)
        # Only a non-last PP rank (outside external_launcher) should produce
        # intermediate tensors. parallel_config is the value read above.
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        # launch non-blocking send of intermediate tensors
        self._pp_send_work = get_pp_group().isend_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        return None

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the model runner's draft token ids, or ``None``.

        Delegates to ``self.model_runner.take_draft_token_ids()``.
        """
        runner = self.model_runner
        return runner.take_draft_token_ids()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        # Check if profiling is enabled
        if self.profiler_config is None or self.profiler_config.profiler is None:
761
762
763
764
765
766
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
767

768
        if is_start:
769
770
771
772
773
774
775
776
777
778
779
780
781
            # Generate the trace name by combining prefix with comprehensive rank suffix
            from vllm.distributed.utils import get_worker_rank_suffix

            rank_suffix = get_worker_rank_suffix(global_rank=self.rank)

            # Build the full trace name
            if profile_prefix:
                trace_name = f"{profile_prefix}_{rank_suffix}"
            else:
                trace_name = rank_suffix

            # Create the profiler wrapper only on the first start call
            if self.profiler is None:
782
783
                profiler_type = self.profiler_config.profiler
                if profiler_type == "torch":
784
785
786
787
788
789
790
791
792
                    self.profiler = TorchProfilerWrapper(
                        self.profiler_config,
                        worker_name=trace_name,
                        local_rank=self.local_rank,
                        activities=["CPU", "CUDA"],
                    )
                    logger.debug(
                        "Starting torch profiler with trace name: %s", trace_name
                    )
793
                elif profiler_type == "cuda":
794
795
                    self.profiler = CudaProfilerWrapper(self.profiler_config)
                    logger.debug("Starting CUDA profiler")
796
                else:
797
798
799
800
801
802
803
804
                    # Config validation should prevent this code being reached
                    raise ValueError(
                        f"Invalid profiler value of {self.profiler_config.profiler}"
                    )

            # If profiler already initialized, restart profiling but keep
            # the original trace name from the first initialization.
            self.profiler.start()
805
        else:
806
807
808
            if self.profiler is None:
                logger.warning("Profiler was not started, nothing to stop.")
                return
809
810
            self.profiler.stop()

811
    def execute_dummy_batch(self) -> None:
        """Run a single-token dummy pass through the model runner.

        Calls ``self.model_runner._dummy_run(1, uniform_decode=True)`` and
        discards its result.
        """
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model runner's ``add_lora``.

        Returns the model runner's result unchanged.
        """
        runner = self.model_runner
        return runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a removal request for ``lora_id`` to the model runner.

        Returns the model runner's result unchanged.
        """
        runner = self.model_runner
        return runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by ``self.model_runner.list_loras()``."""
        runner = self.model_runner
        return runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a pin request for ``lora_id`` to the model runner.

        Returns the model runner's result unchanged.
        """
        runner = self.model_runner
        return runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """Report worker health.

        This is a no-op: the worker counts as healthy for as long as it is
        running, so there is nothing to check.
        """
        return None

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model runner's model with ``ShardedStateLoader.save_model``.

        Args:
            path: Destination forwarded unchanged to the loader.
            pattern: Optional pattern forwarded unchanged (presumably a shard
                file-name pattern; see ``ShardedStateLoader``).
            max_size: Optional limit forwarded unchanged (presumably a
                per-shard size cap; see ``ShardedStateLoader``).
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

    def save_tensorized_model(self, tensorizer_config: "TensorizerConfig") -> None:
        """Serialize the model with ``TensorizerLoader.save_model``.

        Passes the model from ``self.get_model()``, the given
        ``tensorizer_config`` and this worker's ``model_config``.
        """
        model = self.get_model()
        TensorizerLoader.save_model(
            model,
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )

    def init_weight_transfer_engine(self, init_info: dict) -> None:
        """Set up the weight transfer mechanism from ``init_info``.

        For the NCCL backend this creates a process group with the trainer.

        Args:
            init_info: Backend-specific initialization info. The engine's
                ``parse_init_info`` turns it into a typed dataclass, which is
                then passed to ``init_transfer_engine``.

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )
        engine.init_transfer_engine(engine.parse_init_info(init_info))

    def update_weights(self, update_info: dict) -> None:
        """Apply a batched weight update received from the trainer.

        There are two update paths:
        - Checkpoint-format updates go through the model's own
          ``load_weights``. This runs between ``initialize_layerwise_reload``
          and ``finalize_layerwise_reload``, under ``torch.device(self.device)``.
        - Kernel-format updates are copied directly into the parameters
          looked up by name.

        Args:
            update_info: Backend-specific update info, parsed by the engine's
                ``parse_update_info`` into its typed dataclass.

        Raises:
            RuntimeError: If weight transfer is not configured.
            ValueError: If a kernel-format tensor's shape differs from the
                shape of the parameter it targets.
        """
        if self.weight_transfer_engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )

        # Parse dict into backend-specific typed dataclass
        typed_update_info = self.weight_transfer_engine.parse_update_info(update_info)

        model = self.model_runner.model

        if typed_update_info.is_checkpoint_format:
            from vllm.model_executor.model_loader.reload import (
                finalize_layerwise_reload,
                initialize_layerwise_reload,
            )

            # Use layerwise reload pattern for checkpoint format weights
            with torch.device(self.device):
                initialize_layerwise_reload(model)
                self.weight_transfer_engine.receive_weights(
                    typed_update_info,
                    load_weights=model.load_weights,
                )
                finalize_layerwise_reload(model, self.model_config)
        else:
            # Weights are already in kernel format, copy directly
            def load_weights_direct(
                weights: list[tuple[str, torch.Tensor]],
            ) -> None:
                # Outside no_grad, an in-place copy_ into a parameter that
                # requires grad raises.
                with torch.no_grad():
                    for name, weight in weights:
                        param = model.get_parameter(name)
                        # copy_ broadcasts its source, so a smaller but
                        # broadcastable tensor would silently fill the whole
                        # parameter with repeated values. Reject mismatches.
                        if param.shape != weight.shape:
                            raise ValueError(
                                f"Shape mismatch for parameter {name!r}: "
                                f"expected {tuple(param.shape)}, "
                                f"got {tuple(weight.shape)}"
                            )
                        param.copy_(weight)

            self.weight_transfer_engine.receive_weights(
                typed_update_info,
                load_weights=load_weights_direct,
            )

    def shutdown(self) -> None:
        """Release the resources this worker holds.

        Calls ``ensure_kv_transfer_shutdown()``. It also shuts down the
        profiler wrapper and the weight transfer engine when they are present.
        """
        # Module-level names such as ensure_kv_transfer_shutdown can already
        # be None during interpreter shutdown.
        if ensure_kv_transfer_shutdown is not None:
            ensure_kv_transfer_shutdown()

        # Use getattr, as for weight_transfer_engine below. shutdown() may run
        # on a worker whose attributes were never fully set up (presumably a
        # failed __init__), and a direct access would raise AttributeError.
        if (profiler := getattr(self, "profiler", None)) is not None:
            profiler.shutdown()

        if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
            weight_transfer_engine.shutdown()

    def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
        """Forward ``execute_method`` and its arguments to the elastic EP executor.

        Returns whatever ``self.elastic_ep_executor.execute`` returns.
        """
        executor = self.elastic_ep_executor
        return executor.execute(execute_method, *args, **kwargs)


def init_worker_distributed_environment(
930
    vllm_config: VllmConfig,
931
    rank: int,
932
    distributed_init_method: str | None = None,
933
    local_rank: int = -1,
934
    backend: str = "nccl",
935
936
) -> None:
    """Initialize the distributed environment."""
937
    attention_config = vllm_config.attention_config
938
    parallel_config = vllm_config.parallel_config
939
940
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

941
    init_batch_invariance(attention_config.backend)
942
    override_envs_for_eplb(parallel_config)
943
944
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

945
    init_method = distributed_init_method or "env://"
946
947
948
949
950

    timeout = None
    if parallel_config.distributed_timeout_seconds is not None:
        timeout = timedelta(seconds=parallel_config.distributed_timeout_seconds)

951
    init_distributed_environment(
952
953
954
955
956
957
        parallel_config.world_size,
        rank,
        init_method,
        local_rank,
        backend,
        timeout,
958
    )
959

960
961
962
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
963
        parallel_config.prefill_context_parallel_size,
964
965
        parallel_config.decode_context_parallel_size,
    )
966
967
968
969

    # Init ec connector here before KV caches caches init
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)