gpu_worker.py 43.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from collections.abc import Callable
8
from contextlib import AbstractContextManager, nullcontext
9
from datetime import timedelta
10
from types import NoneType
11
from typing import TYPE_CHECKING, Any
12

13
import numpy as np
14
import torch
15
import torch.nn as nn
16

17
import vllm.envs as envs
18
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
19
from vllm.config.compilation import CompilationMode
20
21
22
23
24
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
25
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
26
from vllm.distributed.eplb.eplb_utils import override_envs_for_eplb
27
28
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
29
    ensure_kv_transfer_shutdown,
30
31
32
    get_kv_transfer_group,
    has_kv_transfer_group,
)
33
from vllm.distributed.parallel_state import (
34
    Handle,
35
36
37
    get_pp_group,
    get_tp_group,
)
38
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
39
from vllm.logger import init_logger
40
from vllm.lora.request import LoRARequest
41
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
42
from vllm.platforms import current_platform
43
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
44
from vllm.sequence import IntermediateTensors
45
from vllm.tasks import SupportedTask
46
from vllm.tracing import instrument
47
from vllm.utils.mem_constants import GiB_bytes
48
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
49
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
50
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
51
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
52
53
54
55
56
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
57
from vllm.v1.utils import compute_iteration_details, report_usage_stats
58
from vllm.v1.worker.utils import is_residual_scattered_for_sp
59
from vllm.v1.worker.worker_base import WorkerBase
60
from vllm.v1.worker.workspace import init_workspace_manager
61

62
from ...model_executor.model_loader import TensorizerLoader
63
from .gpu.warmup import warmup_kernels
64
65
from .utils import request_memory

66
67
68
logger = init_logger(__name__)

if TYPE_CHECKING:
69
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
70
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
71
72


73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
class AsyncIntermediateTensors(IntermediateTensors):
    """IntermediateTensors whose communication is synchronized lazily.

    Pending comm handles are waited on, and the optional postprocess
    callbacks are run, the first time ``.tensors`` is read (or when
    ``wait_for_comm()`` is called explicitly).
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        comm_handles: list[Handle] | None = None,
        comm_postprocess: list[Callable[[], None]] | None = None,
    ) -> None:
        super().__init__(tensors)
        self._comm_handles = comm_handles
        self._comm_postprocess = comm_postprocess
        # Set to True once wait_for_comm() has completed successfully.
        self._comm_waited = False

    def wait_for_comm(self) -> None:
        """Wait on all comm handles, then run the postprocess callbacks.

        Only the first call that completes does any work; later calls
        return immediately.
        """
        if not self._comm_waited:
            for pending in self._comm_handles or ():
                pending.wait()
            for callback in self._comm_postprocess or ():
                callback()
            self._comm_waited = True

    def __getattribute__(self, name: str):
        # Go through object.__getattribute__ directly so these lookups do not
        # re-enter this override; only `.tensors` triggers the lazy wait.
        lookup = object.__getattribute__
        if name == "tensors" and not lookup(self, "_comm_waited"):
            lookup(self, "wait_for_comm")()
        return lookup(self, name)


105
class Worker(WorkerBase):
106
107
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create the worker; device and model runner are set up in init_device().

        Args:
            vllm_config: Full vLLM configuration.
            local_rank: Rank of this worker on its node.
            rank: Global rank of this worker.
            distributed_init_method: Init method for the distributed backend.
            is_driver_worker: Whether this worker is the driver worker.

        Raises:
            ValueError: If ``profiler_config.profiler`` is not "torch",
                "cuda" or None.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # float32 matmul precision comes from vLLM's environment settings.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor

        self.elastic_ep_executor = ElasticEPScalingExecutor(self)

        # CPU copies of model buffers taken by a level-2 sleep().
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Weight transfer engine: only built when a config for it is present.
        if self.vllm_config.weight_transfer_config is None:
            self.weight_transfer_engine = None
        else:
            self.weight_transfer_engine = WeightTransferEngineFactory.create_engine(
                self.vllm_config.weight_transfer_config,
                self.vllm_config.parallel_config,
            )

        # Torch/CUDA profiler, configured through profiler_config. The wrapper
        # is created lazily in profile() when start is called (so trace naming
        # has all the information it needs); here the type is only validated.
        self.profiler: Any | None = None
        self.profiler_config = vllm_config.profiler_config
        if self.profiler_config.profiler not in ("torch", "cuda", None):
            raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}")

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

        # Non-blocking PP send work still pending from the previous iteration.
        self._pp_send_work: list[Handle] = []
Woosuk Kwon's avatar
Woosuk Kwon committed
156

157
    def sleep(self, level: int = 1) -> None:
        """Put the worker to sleep through the CuMem allocator.

        Level 1 passes ``offload_tags=("weights",)`` to the allocator; any
        other level passes no offload tags. Level 2 additionally saves CPU
        copies of the model's named buffers so wake_up() can restore them.
        The amount of device memory freed is logged.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        if level == 2:
            # Keep CPU copies of all named buffers for wake_up().
            self._sleep_saved_buffers = dict(
                (buf_name, buf.cpu().clone())
                for buf_name, buf in self.model_runner.model.named_buffers()
            )

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(freed_bytes),
            format_gib(total - free_after),
        )
180

181
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up memory pools previously released by sleep().

        Args:
            tags: Memory-pool tags to wake up (this file uses "weights" and
                "kv_cache"); None wakes up every pool.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers saved by a level-2 sleep. Model buffers created
        # during load_model() live in the "weights" memory pool, so only copy
        # into them once that pool is awake; otherwise keep the saved copies
        # for a later wake_up() call that includes "weights".
        weights_awake = tags is None or "weights" in tags
        if self._sleep_saved_buffers and weights_awake:
            saved = self._sleep_saved_buffers
            for name, buffer in self.model_runner.model.named_buffers():
                if name in saved:
                    buffer.data.copy_(saved[name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up,
        # the internal state of cache_engine must be reset,
        # especially the FP8 scaling factor.
        if (
            (tags is None or "kv_cache" in tags)
            and self.cache_config.cache_dtype.startswith("fp8")
            and hasattr(self.model_runner, "init_fp8_kv_scales")
        ):
            self.model_runner.init_fp8_kv_scales()

205
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context manager that allocates into a tagged CuMem pool.

        With sleep mode disabled this is ``nullcontext()``. Otherwise it is
        ``CuMemAllocator.use_memory_pool(tag=tag)``.

        Raises:
            AssertionError: For ``tag == "weights"`` when the allocator already
                reports non-zero usage (one sleep-mode instance per process).
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            cumem = CuMemAllocator.get_instance()
            if tag == "weights":
                assert cumem.get_current_usage() == 0, (
                    "Sleep mode can only be used for one instance per process."
                )
            return cumem.use_memory_pool(tag=tag)
        return nullcontext()
217

218
    @instrument(span_name="Init device")
    def init_device(self):
        """Bind the worker to its CUDA device and construct the model runner.

        The steps are:
        1. Choose the device index. It is shifted by the DP rank unless
           Ray/external_launcher manage placement or DP spans nodes.
        2. Initialize the distributed environment.
        3. Take the baseline memory snapshot used by memory profiling.
        4. Initialize the workspace manager.
        5. Create the V1 or V2 GPU model runner.

        Raises:
            RuntimeError: If the configured device type is not "cuda".
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        pcfg = self.parallel_config
        adjust_for_dp = (
            pcfg.distributed_executor_backend not in ("ray", "external_launcher")
            and pcfg.data_parallel_backend != "ray"
            and pcfg.nnodes_within_dp == 1
        )
        if adjust_for_dp:
            # Prefer the node-local DP rank, falling back to the global DP index.
            dp_rank = pcfg.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = pcfg.data_parallel_index

            ranks_per_dp_replica = (
                pcfg.pipeline_parallel_size * pcfg.tensor_parallel_size
            )
            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_rank * ranks_per_dp_replica
            assert self.local_rank < torch.accelerator.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )
            num_visible = (
                torch.accelerator.device_count() if torch.cuda.is_available() else 0
            )
            assert pcfg.local_world_size <= num_visible, (
                f"local_world_size ({pcfg.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({num_visible})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.accelerator.set_device_index(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment before the memory snapshot so
        # NCCL buffers are already allocated when available memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        if self.use_v2_model_runner:
            logger.info_once("Using V2 Model Runner", scope="local")

        set_random_seed(self.model_config.seed)

        # Collect garbage and empty the allocator cache before snapshotting.
        gc.collect()
        torch.accelerator.empty_cache()

        snapshot = MemorySnapshot(device=self.device)
        self.init_snapshot = snapshot
        self.requested_memory = request_memory(snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Two micro-batches when enable_dbo is set, otherwise one.
        ubatch_count = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, ubatch_count)

        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        # Only rank 0 reports usage stats (when usage stats are enabled).
        if self.rank == 0:
            report_usage_stats(self.vllm_config)

316
317
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
318
    def load_model(self, *, load_dummy_weights: bool = False) -> None:
        """Load the model through the model runner.

        Loading happens under the "weights" memory-pool context (a no-op
        unless sleep mode is enabled) and with this worker's config set as
        the current vLLM config.

        Args:
            load_dummy_weights: Forwarded to the model runner's load_model().
        """
        weights_pool = self._maybe_get_memory_pool_context(tag="weights")
        with weights_pool, set_current_vllm_config(self.vllm_config):
            self.model_runner.load_model(load_dummy_weights=load_dummy_weights)
324

325
326
327
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner's update_config()."""
        runner = self.model_runner
        runner.update_config(overrides)

328
329
    def reload_weights(self, *args, **kwargs) -> None:
        """Pass all arguments through to the model runner's reload_weights()."""
        runner = self.model_runner
        runner.reload_weights(*args, **kwargs)
330

331
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile the model's peak memory to size the KV cache without OOMs.

        First profiles memory usage with a dummy forward pass, starting from
        the snapshot taken in init_device(). It then computes how many bytes
        of the requested budget (``gpu_memory_utilization``) remain for the
        KV cache. When ``VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS`` is set,
        the CUDA graph memory estimate is subtracted as well.

        If ``cache_config.kv_cache_memory_bytes`` is set, profiling is skipped
        and that value is returned directly. A profile run still happens so
        that the model is compiled for max_num_batched_tokens.

        In the profiling path this also sets ``non_torch_memory``,
        ``peak_activation_memory``, ``cudagraph_memory_estimate`` and
        ``available_kv_cache_memory_bytes`` on ``self``.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            Bytes available for the KV cache.
        """
        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {format_gib(self.init_snapshot.free_memory)} "
                f"GiB, reserved {format_gib(kv_cache_memory_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return kv_cache_memory_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

            profile_torch_peak = torch.accelerator.memory_stats(self.device).get(
                "allocated_bytes.all.peak", 0
            )

            # Profile CUDA graph memory if graphs will be captured.
            # Skip on ROCm/HIP as graph pool handles and mem_get_info behave
            # differently and can produce incorrect/negative estimates.
            cudagraph_memory_estimate = 0
            if not self.model_config.enforce_eager and not current_platform.is_rocm():
                cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()

        # Use the pre-cudagraph torch peak to avoid double-counting.
        profile_result.torch_peak_increase = (
            profile_torch_peak - profile_result.before_profile.torch_peak
        )
        profile_result.non_kv_cache_memory = (
            profile_result.non_torch_increase
            + profile_result.torch_peak_increase
            + profile_result.weights_memory
        )

        # On ROCm, cudagraph_memory_estimate is always 0 so this is a no-op.
        # On CUDA, respect the opt-in flag as originally designed.
        cudagraph_memory_estimate_applied = (
            cudagraph_memory_estimate
            if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
            else 0
        )

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase
        self.cudagraph_memory_estimate = cudagraph_memory_estimate

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory >= free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {format_gib(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory
            - profile_result.non_kv_cache_memory
            - cudagraph_memory_estimate_applied
        )

        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            format_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_gpu_memory),
            format_gib(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )

        self._log_cudagraph_memory_hint(cudagraph_memory_estimate)

        return int(self.available_kv_cache_memory_bytes)

    def _log_cudagraph_memory_hint(self, cudagraph_memory_estimate: int) -> None:
        """Log how CUDA graph memory profiling interacts with gpu_memory_utilization.

        Does nothing when no CUDA graph memory was estimated. Otherwise it
        logs the ``--gpu-memory-utilization`` value that keeps the effective
        KV cache size unchanged once the CUDA graph estimate is (or will be)
        subtracted from the budget.
        """
        if cudagraph_memory_estimate <= 0:
            return

        total_mem = self.init_snapshot.total_memory
        current_util = self.cache_config.gpu_memory_utilization
        cg_util_delta = cudagraph_memory_estimate / total_mem
        # Both messages suggest adding the CUDA graph share back, capped at 1.0.
        suggested_util = min(round(current_util + cg_util_delta, 4), 1.0)

        if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS:
            equiv_util = round(current_util - cg_util_delta, 4)
            logger.info(
                "CUDA graph memory profiling is enabled "
                "(VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). "
                "This will become the default in v0.19. "
                "The current --gpu-memory-utilization=%.4f is equivalent "
                "to --gpu-memory-utilization=%.4f without CUDA graph "
                "memory profiling. To maintain the same effective KV "
                "cache size as before, increase "
                "--gpu-memory-utilization to %.4f.",
                current_util,
                equiv_util,
                suggested_util,
            )
        else:
            logger.info(
                "In v0.19, CUDA graph memory profiling will be enabled "
                "by default (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1), "
                "which more accurately accounts for CUDA graph memory "
                "during KV cache allocation. To try it now, set "
                "VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 and increase "
                "--gpu-memory-utilization from %.4f to %.4f to maintain "
                "the same effective KV cache size.",
                current_util,
                suggested_util,
            )
483

484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Get this worker's KV connector handshake metadata, if any.

        Returns:
            ``{tp_rank: metadata}``, keyed by this worker's rank in the tensor
            parallel group, or None when there is no KV transfer group or the
            connector has no handshake metadata to exchange.
        """
        if not has_kv_transfer_group():
            return None

        handshake = get_kv_transfer_group().get_handshake_metadata()
        if handshake is None:
            # This connector does not exchange handshake metadata across workers.
            return None

        return {get_tp_group().rank_in_group: handshake}

499
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        spec = self.model_runner.get_kv_cache_spec()
        return spec

502
503
504
505
506
507
508
509
510
    def update_max_model_len(self, max_model_len: int) -> None:
        """Update max_model_len after auto-fit to GPU memory.

        Called when max_model_len=-1 is used and the engine determines the
        longest context that fits in GPU memory. The new value is written to
        the model config and, if a model runner exists, pushed to it as well.

        Args:
            max_model_len: The engine's chosen maximum model length.
        """
        self.model_config.max_model_len = max_model_len
        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)
        logger.debug("Updated max_model_len to %d", max_model_len)

514
    @instrument(span_name="Allocate KV cache")
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        Also initializes the KV transfer connector and, when enabled, the
        routed-experts capturer and the KV-zeroing metadata.

        Args:
            kv_cache_config: KV cache layout; ``num_blocks`` is also copied
                into ``cache_config.num_gpu_blocks``.
        """
        # Update local config with adjusted num blocks after profiling,
        # so that it's available to the warmup stage.
        self.cache_config.num_gpu_blocks = kv_cache_config.num_blocks

        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # With sleep mode enabled, allocate the KV cache inside the "kv_cache"
        # CuMem pool so sleep()/wake_up() can manage it; otherwise this is a
        # no-op context. Reuses the same helper that load_model() uses.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

        if self.model_config.enable_return_routed_experts:
            self.model_runner.init_routed_experts_capturer()

        # Build KV-zero metadata outside the CuMem pool so the bookkeeping
        # GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
        # allocator and are not discarded during sleep/wake cycles.
        if kv_cache_config.needs_kv_cache_zeroing and hasattr(
            self.model_runner, "_init_kv_zero_meta"
        ):
            self.model_runner._init_kv_zero_meta()

549
    @instrument(span_name="Warmup (GPU)")
    def compile_or_warm_up_model(self) -> float:
        """Compile and warm up the model, then capture CUDA graphs.

        The steps are:
        1. Run eager dummy passes for compile sizes not covered by CUDA graph
           capture.
        2. Warm up kernels.
        3. Capture CUDA graphs, unless ``enforce_eager`` is set.
        4. Log memory diagnostics.
        5. Warm up the sampler/pooler (V1), or the full execute+sample path
           (V2).
        6. Re-seed the RNG.

        Returns:
            ``self.compilation_config.compilation_time``.
        """
        warmup_sizes = self._get_compile_warmup_sizes()

        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        # Compare actual vs estimated CUDA graph memory (if we did profiling).
        # The estimate is only set when determine_available_memory() profiled.
        cudagraph_memory_estimate = getattr(self, "cudagraph_memory_estimate", 0)
        if cudagraph_memory_estimate > 0:

            def to_gib(num_bytes: float) -> float:
                return round(num_bytes / GiB_bytes, 2)

            diff = abs(cuda_graph_memory_bytes - cudagraph_memory_estimate)
            logger.info(
                "CUDA graph pool memory: %s GiB (actual), %s GiB (estimated), "
                "difference: %s GiB (%.1f%%).",
                to_gib(cuda_graph_memory_bytes),
                to_gib(cudagraph_memory_estimate),
                to_gib(diff),
                100 * diff / max(cuda_graph_memory_bytes, 1),
            )

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)

            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({format_gib(self.init_snapshot.free_memory)}/"
                f"{format_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{format_gib(self.requested_memory)} GiB). "
                f"Actual usage is {format_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {format_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({format_gib(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({format_gib(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        if self.use_v2_model_runner:
            # V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
            warmup_kernels(self.model_runner, self.execute_model, self.sample_tokens)
        elif get_pp_group().is_last_rank:
            # V1: Warm up sampler and preallocate memory buffer for logits and other
            # sampling related tensors of max possible shape to avoid memory
            # fragmentation issue.
            # NOTE: This is called after `capture_model` on purpose to prevent
            # memory buffers from being cleared by `torch.accelerator.empty_cache`.
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

        return self.compilation_config.compilation_time

    def _get_compile_warmup_sizes(self) -> list[int]:
        """Return batch sizes to compile/warm up outside CUDA graph capture.

        The list is empty unless the mode is ``CompilationMode.VLLM_COMPILE``.
        In that mode it starts from ``compile_sizes``, drops sizes that
        CUDA graph capture already covers, and appends the end of every
        compile range that no warmup or capture size falls into.
        """
        compilation_config = self.vllm_config.compilation_config
        if compilation_config.mode != CompilationMode.VLLM_COMPILE:
            return []

        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        compile_sizes = compilation_config.compile_sizes
        warmup_sizes: list[int] = []
        if compile_sizes is not None:
            warmup_sizes = compile_sizes.copy()  # type: ignore[assignment]

        cg_capture_sizes: list[int] = []
        if compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
            cg_sizes = compilation_config.cudagraph_capture_sizes
            cg_capture_sizes = [] if cg_sizes is None else cg_sizes
            warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes]

        compile_ranges = compilation_config.get_compile_ranges()
        # For each compile_range, if none of the batch sizes
        # in warmup_sizes or cudagraph_capture_sizes are in the range,
        # add the end of the range to ensure compilation/warmup.
        all_sizes = set(cg_capture_sizes)
        all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
        for compile_range in compile_ranges:
            if not any(x in compile_range for x in all_sizes):
                warmup_sizes.append(compile_range.end)
        return warmup_sizes

694
695
696
    def reset_mm_cache(self) -> None:
        """Forward the ``mm`` cache reset to the model runner (returns nothing)."""
        runner = self.model_runner
        runner.reset_mm_cache()

697
698
699
    def reset_encoder_cache(self) -> None:
        """Forward the encoder cache reset to the model runner (returns nothing)."""
        runner = self.model_runner
        runner.reset_encoder_cache()

700
701
702
    def get_model(self) -> nn.Module:
        """Return whatever ``model_runner.get_model()`` yields."""
        runner = self.model_runner
        model = runner.get_model()
        return model

703
704
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Report the tasks the underlying model runner says it supports."""
        runner = self.model_runner
        return runner.get_supported_tasks()
705

706
707
708
709
710
    def get_compilation_match_table(self) -> dict[str, int]:
        """Return the match table produced by vLLM's inductor-pass module.

        The import is deferred to call time, matching the original behavior.
        """
        from vllm.compilation.passes.vllm_inductor_pass import (
            get_match_table as _get_match_table,
        )

        table = _get_match_table()
        return table

711
712
713
714
    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Get encoder timing stats from model runner."""
        return self.model_runner.get_encoder_timing_stats()

715
716
    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels this step in the profiler trace.

        Add trace annotation so that we can easily distinguish context/generation
        request numbers in each iteration. A context request is a request that
        has not yet generated any tokens.

        When no profiler is active this returns ``nullcontext()``. Otherwise it
        advances the profiler by one step and returns the profiler's annotation
        context manager, labelled
        ``execute_context_<reqs>(<tokens>)_generation_<reqs>(<tokens>)``.
        """
        if not self.profiler:
            return nullcontext()

        self.profiler.step()

        details = compute_iteration_details(scheduler_output)
        annotation = (
            f"execute_context_{details.num_ctx_requests}"
            f"({details.num_ctx_tokens})"
            f"_generation_{details.num_generation_requests}"
            f"({details.num_generation_tokens})"
        )
        return self.profiler.annotate_context_manager(annotation)

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Sample tokens for the current step via the model runner.

        Runs under ``torch.inference_mode()`` and forwards ``grammar_output``
        (possibly ``None``) unchanged to ``model_runner.sample_tokens``.
        """
        return self.model_runner.sample_tokens(grammar_output)

    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one model step for ``scheduler_output``.

        On pipeline-parallel (PP) ranks other than the first, intermediate
        tensors are received with ``irecv_tensor_dict`` and wrapped in
        ``AsyncIntermediateTensors`` before the forward pass. If the model runner
        returns ``IntermediateTensors`` (only allowed on non-last PP ranks), they
        are sent to the next rank with a non-blocking ``isend_tensor_dict`` and
        ``None`` is returned.

        Returns:
            The model runner output (or ``None``) when the runner produces a
            final output, otherwise ``None`` after launching the PP send.
        """
        # ensure any previous non-blocking PP sends are complete
        if self._pp_send_work:
            for handle in self._pp_send_work:
                handle.wait()
            self._pp_send_work = []

        intermediate_tensors = None
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0
        all_gather_tensors = {}
        compilation_config = self.vllm_config.compilation_config
        parallel_config = self.vllm_config.parallel_config

        if (
            parallel_config.pipeline_parallel_size > 1
            and compilation_config.pass_config.enable_sp
            and forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            num_scheduled_tokens_np = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(num_scheduled_tokens_np),
                    num_scheduled_tokens_np=num_scheduled_tokens_np,
                    max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            # The same flags are used for both the PP receive and send below.
            # NOTE(review): presumably marks whether "residual" must be
            # all-gathered over the TP group for this padded size — confirm.
            all_gather_tensors = {
                "residual": not is_residual_scattered_for_sp(
                    self.vllm_config, batch_desc.num_tokens
                )
            }

        if forward_pass and not get_pp_group().is_first_rank:
            tensor_dict, comm_handles, comm_postprocess = (
                get_pp_group().irecv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )
            assert tensor_dict is not None
            intermediate_tensors = AsyncIntermediateTensors(
                tensor_dict,
                comm_handles=comm_handles,
                comm_postprocess=comm_postprocess,
            )

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if (
                self.use_v2_model_runner
                and self.model_runner.is_pooling_model
                and output is None
            ):
                output = self.model_runner.pool()  # type: ignore
            if isinstance(
                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
            ):
                return output

        # Only non-last PP ranks may produce intermediate tensors.
        assert isinstance(output, IntermediateTensors)
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        # launch non-blocking send of intermediate tensors
        self._pp_send_work = get_pp_group().isend_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        return None

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return ``model_runner.take_draft_token_ids()`` (may be ``None``)."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        # Check if profiling is enabled
        if self.profiler_config is None or self.profiler_config.profiler is None:
844
845
846
847
848
849
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
850

851
        if is_start:
852
853
854
855
856
857
858
859
860
861
862
863
864
            # Generate the trace name by combining prefix with comprehensive rank suffix
            from vllm.distributed.utils import get_worker_rank_suffix

            rank_suffix = get_worker_rank_suffix(global_rank=self.rank)

            # Build the full trace name
            if profile_prefix:
                trace_name = f"{profile_prefix}_{rank_suffix}"
            else:
                trace_name = rank_suffix

            # Create the profiler wrapper only on the first start call
            if self.profiler is None:
865
866
                profiler_type = self.profiler_config.profiler
                if profiler_type == "torch":
867
868
869
870
871
872
873
874
875
                    self.profiler = TorchProfilerWrapper(
                        self.profiler_config,
                        worker_name=trace_name,
                        local_rank=self.local_rank,
                        activities=["CPU", "CUDA"],
                    )
                    logger.debug(
                        "Starting torch profiler with trace name: %s", trace_name
                    )
876
                elif profiler_type == "cuda":
877
878
                    self.profiler = CudaProfilerWrapper(self.profiler_config)
                    logger.debug("Starting CUDA profiler")
879
                else:
880
881
882
883
884
885
886
887
                    # Config validation should prevent this code being reached
                    raise ValueError(
                        f"Invalid profiler value of {self.profiler_config.profiler}"
                    )

            # If profiler already initialized, restart profiling but keep
            # the original trace name from the first initialization.
            self.profiler.start()
888
        else:
889
890
891
            if self.profiler is None:
                logger.warning("Profiler was not started, nothing to stop.")
                return
892
893
            self.profiler.stop()

894
    def execute_dummy_batch(self) -> None:
        """Run one uniform-decode dummy step on the model runner.

        The token count is the runner's ``uniform_decode_query_len`` when that
        attribute exists, otherwise 1.
        """
        num_tokens = getattr(self.model_runner, "uniform_decode_query_len", 1)
        self.model_runner._dummy_run(num_tokens, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to ``model_runner.add_lora``; return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

901
902
903
    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to ``model_runner.remove_lora``; return its result."""
        removed = self.model_runner.remove_lora(lora_id)
        return removed

904
    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by ``model_runner.list_loras()``."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to ``model_runner.pin_lora``; return its result."""
        runner = self.model_runner
        pinned = runner.pin_lora(lora_id)
        return pinned

910
911
912
913
    def check_health(self) -> None:
        """Health probe for this worker.

        There is nothing to check: the worker counts as healthy for as long as
        it is running, so this always returns ``None`` without raising.
        """
        return None

914
915
916
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model runner's model via ``ShardedStateLoader.save_model``.

        Args:
            path: Destination forwarded to ``ShardedStateLoader.save_model``.
            pattern: Optional ``pattern`` argument, forwarded unchanged.
            max_size: Optional ``max_size`` argument, forwarded unchanged.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(self, tensorizer_config: "TensorizerConfig") -> None:
        """Serialize the model with ``TensorizerLoader.save_model``.

        Args:
            tensorizer_config: Tensorizer settings forwarded to the loader,
                together with this worker's ``model_config``.
        """
        TensorizerLoader.save_model(
            self.get_model(),
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )

    def init_weight_transfer_engine(self, init_info: dict) -> None:
        """Initialize the weight transfer mechanism from ``init_info``.

        For NCCL backend, this creates a process group with the trainer.

        Args:
            init_info: Dictionary containing backend-specific initialization info

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )
        # The engine converts the raw dict into its backend-specific typed info.
        engine.init_transfer_engine(engine.parse_init_info(init_info))

    def update_weights(self, update_info: dict) -> None:
        """
        Batched weight update from the trainer.

        Checkpoint-format updates go through the layerwise reload helpers and
        ``model.load_weights``. Kernel-format updates are copied directly into
        the named parameters. The call ends with ``torch.accelerator.synchronize()``.

        Args:
            update_info: Dictionary containing backend-specific update info

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        if self.weight_transfer_engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )

        # Parse dict into backend-specific typed dataclass
        typed_update_info = self.weight_transfer_engine.parse_update_info(update_info)

        model = self.model_runner.model

        if typed_update_info.is_checkpoint_format:
            from vllm.model_executor.model_loader.reload import (
                finalize_layerwise_reload,
                initialize_layerwise_reload,
            )

            # Use layerwise reload pattern for checkpoint format weights
            with torch.device(self.device):
                initialize_layerwise_reload(model)
                self.weight_transfer_engine.receive_weights(
                    typed_update_info,
                    load_weights=model.load_weights,
                )
                finalize_layerwise_reload(model, self.model_config)
        else:
            # Weights are already in kernel format, copy directly
            def load_weights_direct(
                weights: list[tuple[str, torch.Tensor]],
            ) -> None:
                # Autograd rejects in-place writes to leaf parameters that
                # require grad, so do the copies with grad tracking disabled.
                with torch.no_grad():
                    for name, weight in weights:
                        param = model.get_parameter(name)
                        param.copy_(weight)

            self.weight_transfer_engine.receive_weights(
                typed_update_info,
                load_weights=load_weights_direct,
            )

        # NCCL broadcast/packed path are asynchronous.
        # Sync here so the next step uses the new weights.
        torch.accelerator.synchronize()

    def shutdown(self) -> None:
        """Shut down the KV-transfer state, profiler and weight transfer engine.

        Each step is skipped when its target is absent, so calling this on a
        partially initialized worker does not raise ``AttributeError``.
        """
        # ensure_kv_transfer_shutdown can be None during interpreter shutdown
        # (presumably because module globals are being torn down), so guard it.
        if ensure_kv_transfer_shutdown is not None:
            ensure_kv_transfer_shutdown()
        # getattr keeps shutdown safe if __init__ never set these attributes.
        if (profiler := getattr(self, "profiler", None)) is not None:
            profiler.shutdown()
        if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
            weight_transfer_engine.shutdown()

    def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
        """Forward ``execute_method`` and its arguments to the elastic EP executor."""
        executor = self.elastic_ep_executor
        result = executor.execute(execute_method, *args, **kwargs)
        return result

1016
1017

def init_worker_distributed_environment(
1018
    vllm_config: VllmConfig,
1019
    rank: int,
1020
    distributed_init_method: str | None = None,
1021
    local_rank: int = -1,
1022
    backend: str = "nccl",
1023
1024
) -> None:
    """Initialize the distributed environment."""
1025
    attention_config = vllm_config.attention_config
1026
    parallel_config = vllm_config.parallel_config
1027
1028
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

1029
    init_batch_invariance(attention_config.backend)
1030
    override_envs_for_eplb(parallel_config)
1031
1032
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

1033
    init_method = distributed_init_method or "env://"
1034
1035
1036
1037
1038

    timeout = None
    if parallel_config.distributed_timeout_seconds is not None:
        timeout = timedelta(seconds=parallel_config.distributed_timeout_seconds)

1039
    init_distributed_environment(
1040
1041
1042
1043
1044
1045
        parallel_config.world_size,
        rank,
        init_method,
        local_rank,
        backend,
        timeout,
1046
    )
1047

1048
1049
1050
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
1051
        parallel_config.prefill_context_parallel_size,
1052
1053
        parallel_config.decode_context_parallel_size,
    )
1054

1055
    # Init ec connector here before KV caches init
1056
1057
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)