gpu_worker.py 38.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from collections.abc import Callable
8
from contextlib import AbstractContextManager, nullcontext
9
from types import NoneType
10
from typing import TYPE_CHECKING, Any
11

12
import numpy as np
13
import torch
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
18
from vllm.config.compilation import CompilationMode
19
20
21
22
23
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
24
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
25
from vllm.distributed.eplb.eplb_utils import override_envs_for_eplb
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
28
    ensure_kv_transfer_shutdown,
29
30
31
    get_kv_transfer_group,
    has_kv_transfer_group,
)
32
from vllm.distributed.parallel_state import (
33
    Handle,
34
35
36
    get_pp_group,
    get_tp_group,
)
37
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
38
from vllm.logger import init_logger
39
from vllm.lora.request import LoRARequest
40
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
41
from vllm.platforms import current_platform
42
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
43
from vllm.sequence import IntermediateTensors
44
from vllm.tasks import SupportedTask
45
from vllm.tracing import instrument
46
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
47
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
48
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
49
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
50
51
52
53
54
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
55
from vllm.v1.utils import compute_iteration_details, report_usage_stats
56
from vllm.v1.worker.utils import is_residual_scattered_for_sp
57
from vllm.v1.worker.worker_base import WorkerBase
58
from vllm.v1.worker.workspace import init_workspace_manager
59

60
from .gpu.warmup import warmup_kernels
61
62
from .utils import request_memory

63
64
65
logger = init_logger(__name__)

if TYPE_CHECKING:
66
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
67
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
68
69


70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
class AsyncIntermediateTensors(IntermediateTensors):
    """IntermediateTensors with lazy comm synchronization.

    The tensor dict is filled by non-blocking communication. The pending
    ``comm_handles`` are waited on, and the ``comm_postprocess`` callbacks are
    run, the first time ``.tensors`` is read or when :meth:`wait_for_comm` is
    called explicitly. Afterwards the object behaves like a plain
    ``IntermediateTensors``.
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        comm_handles: list[Handle] | None = None,
        comm_postprocess: list[Callable[[], None]] | None = None,
    ) -> None:
        # Initialize the synchronization state *before* the base initializer
        # so that `__getattribute__` can always find `_comm_waited`.
        self._comm_handles = comm_handles
        self._comm_postprocess = comm_postprocess
        self._comm_waited = False
        super().__init__(tensors)

    def wait_for_comm(self) -> None:
        """Wait for all pending receives and run the post-processing callbacks.

        Idempotent: once it has completed, later calls return immediately.
        """
        if self._comm_waited:
            return
        for handle in self._comm_handles or ():
            handle.wait()
        for fn in self._comm_postprocess or ():
            fn()
        self._comm_waited = True
        # The work is done; drop the handles and callbacks so that anything
        # they reference (e.g. communication buffers) can be released.
        self._comm_handles = None
        self._comm_postprocess = None

    def __getattribute__(self, name: str):
        # Ensure `.tensors` is ready before use. `object.__getattribute__` is
        # used for the internal lookups to avoid re-entering this hook.
        if name == "tensors" and not object.__getattribute__(self, "_comm_waited"):
            object.__getattribute__(self, "wait_for_comm")()
        return object.__getattribute__(self, name)


102
class Worker(WorkerBase):
103
104
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create a GPU worker.

        Args:
            vllm_config: Engine configuration shared by all components.
            local_rank: Rank of this worker on its node; `init_device` may
                shift it to account for data parallelism.
            rank: Global rank of this worker.
            distributed_init_method: Passed on when the distributed
                environment is initialized in `init_device`.
            is_driver_worker: Whether this worker is the driver worker.

        Raises:
            ValueError: if the configured profiler is not "torch", "cuda" or
                None.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Float32 matmul precision comes from the vLLM environment settings.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor

        self.elastic_ep_executor = ElasticEPScalingExecutor(self)

        # Model buffers copied to the CPU by a level-2 `sleep()`.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Weight transfer engine; only created when a config is provided.
        transfer_config = self.vllm_config.weight_transfer_config
        self.weight_transfer_engine = (
            None
            if transfer_config is None
            else WeightTransferEngineFactory.create_engine(
                transfer_config,
                self.vllm_config.parallel_config,
            )
        )

        # Torch/CUDA profiler, enabled and configured through profiler_config.
        # The wrapper itself is created lazily in profile() when start is
        # called, so all information needed for trace naming is available.
        self.profiler: Any | None = None
        self.profiler_config = vllm_config.profiler_config

        # Only validate the profiler choice here; do not instantiate it yet.
        profiler_kind = self.profiler_config.profiler
        if profiler_kind not in ("torch", "cuda", None):
            raise ValueError(f"Unknown profiler type: {profiler_kind}")

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

        # Non-blocking PP sends issued by the previous iteration; they are
        # waited on at the start of the next `execute_model` call.
        self._pp_send_work: list[Handle] = []
Woosuk Kwon's avatar
Woosuk Kwon committed
153

154
    def sleep(self, level: int = 1) -> None:
        """Release GPU memory through the CuMemAllocator sleep mode.

        Level 1 calls `CuMemAllocator.sleep` with ``offload_tags=("weights",)``;
        any other level passes no offload tags. Before a level-2 sleep, every
        model buffer is copied to the CPU so that `wake_up` can restore it.

        Args:
            level: Sleep level (1 or 2).
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before = torch.cuda.mem_get_info()[0]

        if level == 2:
            # Keep CPU copies of the model buffers; `wake_up` copies them back.
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone()
                for buf_name, buf in self.model_runner.model.named_buffers()
            }

        offload_tags = ("weights",) if level == 1 else tuple()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        used_bytes = total - free_after
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(freed_bytes),
            format_gib(used_bytes),
        )
177

178
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Re-map GPU memory released by `sleep`.

        Args:
            tags: Memory-pool tags to wake up (this file uses "weights" and
                "kv_cache"); ``None`` is treated as every tag.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers saved by a level-2 sleep. The model buffers are
        # allocated during `load_model`, which runs inside the "weights"
        # memory pool, so they may only be written once that pool is awake.
        # Otherwise keep the saved copies for a later wake-up of "weights".
        weights_awake = tags is None or "weights" in tags
        if self._sleep_saved_buffers and weights_awake:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up, the internal state of the
        # cache engine must be reset, especially the FP8 scaling factor.
        if (
            (tags is None or "kv_cache" in tags)
            and self.cache_config.cache_dtype.startswith("fp8")
            and hasattr(self.model_runner, "init_fp8_kv_scales")
        ):
            self.model_runner.init_fp8_kv_scales()

202
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context that places allocations in the sleep-mode pool.

        With sleep mode disabled this is a plain ``nullcontext()``. Otherwise
        it is `CuMemAllocator.use_memory_pool(tag=tag)`.

        Args:
            tag: Memory-pool tag, e.g. "weights" or "kv_cache".
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            # The allocator instance is shared (`get_instance`), so it must be
            # empty when the weights are placed in it.
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)
214

215
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU and CPU KV-cache block counts in the cache config."""
        cache_cfg = self.cache_config
        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks

219
    @instrument(span_name="Init device")
    def init_device(self):
        """Bind the worker to its CUDA device and build the model runner.

        When the executor backend is not "ray"/"external_launcher", the DP
        backend is not "ray" and ``nnodes_within_dp == 1``, ``local_rank`` is
        shifted by ``dp_local_rank * (PP * TP)`` before the device is chosen.
        The distributed environment is then initialized, the random seed set,
        the initial memory snapshot taken, and finally the workspace manager
        and model runner are created.

        Raises:
            RuntimeError: if the configured device type is not "cuda".
            AssertionError: if the DP-adjusted local rank or the local world
                size does not fit the visible CUDA devices.
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        parallel_config = self.parallel_config
        if (
            parallel_config.distributed_executor_backend
            not in ("ray", "external_launcher")
            and parallel_config.data_parallel_backend != "ray"
            and parallel_config.nnodes_within_dp == 1
        ):
            # Use local DP rank if available, otherwise use global DP rank.
            dp_local_rank = parallel_config.data_parallel_rank_local
            if dp_local_rank is None:
                dp_local_rank = parallel_config.data_parallel_index

            tp_pp_world_size = (
                parallel_config.pipeline_parallel_size
                * parallel_config.tensor_parallel_size
            )
            visible_device_count = (
                torch.cuda.device_count() if torch.cuda.is_available() else 0
            )

            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_local_rank * tp_pp_world_size
            assert self.local_rank < visible_device_count, (
                f"DP adjusted local rank {self.local_rank} is out of bounds: "
                f"only {visible_device_count} CUDA device(s) are visible "
                f"(DP local rank {dp_local_rank}, TP*PP world size "
                f"{tp_pp_world_size})."
            )
            assert parallel_config.local_world_size <= visible_device_count, (
                f"local_world_size ({parallel_config.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({visible_device_count})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # snapshot, so that NCCL buffers are already allocated when the
        # available memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        if self.use_v2_model_runner:
            logger.info_once("Using V2 Model Runner", scope="local")

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Now take the memory snapshot after NCCL is initialized.
        gc.collect()
        torch.cuda.empty_cache()

        self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
        self.requested_memory = request_memory(init_snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Initialize workspace manager
        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, num_ubatches)

        # Construct the model runner
        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

317
318
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model inside the "weights" memory-pool context.

        If ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH == "1"``, the expert mapping is
        first received from the elastic EP executor, the redundant-expert
        count is derived from it, dummy weights are loaded, and EPLB is then
        set up from the received mapping.
        """
        use_dummy = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        if use_dummy:
            physical_to_logical, n_logical, old_n_physical = (
                self.elastic_ep_executor.receive_expert_mapping()
            )
            n_physical = physical_to_logical.shape[1]
            eplb_cfg = self.parallel_config.eplb_config
            eplb_cfg.num_redundant_experts = n_physical - n_logical

        pool_ctx = self._maybe_get_memory_pool_context(tag="weights")
        with pool_ctx, set_current_vllm_config(self.vllm_config):
            self.model_runner.load_model(load_dummy_weights=use_dummy)

        if use_dummy:
            self.model_runner.setup_eplb_from_mapping(
                physical_to_logical, old_n_physical
            )
            self.model_runner.eep_eplb_suppressed = True
343

344
345
346
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward the config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

347
348
    def reload_weights(self, *args, **kwargs) -> None:
        """Forward all arguments to the model runner's `reload_weights`."""
        runner = self.model_runner
        runner.reload_weights(*args, **kwargs)
349

350
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return how many bytes the KV cache may use without OOMs.

        When ``kv_cache_memory_bytes`` is configured, a profile run is still
        executed (it compiles the model), but that configured value is
        returned directly. Otherwise a dummy forward pass is profiled with
        `memory_profiling` and the result is
        ``requested_memory - profile_result.non_kv_cache_memory``.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        explicit_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if explicit_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {format_gib(self.init_snapshot.free_memory)} "
                f"GiB, reserved {format_gib(explicit_kv_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return explicit_kv_bytes

        # Profile a forward pass with dummy inputs.
        weights_bytes = int(self.model_runner.model_memory_usage)
        with memory_profiling(
            self.init_snapshot, weights_memory=weights_bytes
        ) as profile_result:
            self.model_runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory >= free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {format_gib(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Memory that was free at startup but lies outside the requested budget.
        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        free_within_request = free_gpu_memory - unrequested_memory
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            format_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_gpu_memory),
            format_gib(free_within_request),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )

        return int(self.available_kv_cache_memory_bytes)
430

431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Get KV connector metadata from this worker if available.

        Returns ``{tp_rank: metadata}``, or ``None`` when no KV transfer
        group exists or the connector has no handshake metadata to exchange
        across workers.
        """
        metadata = (
            get_kv_transfer_group().get_handshake_metadata()
            if has_kv_transfer_group()
            else None
        )
        if metadata is None:
            return None
        return {get_tp_group().rank_in_group: metadata}

446
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

449
450
451
452
453
454
455
456
457
458
    def update_max_model_len(self, max_model_len: int) -> None:
        """Update max_model_len after auto-fit to GPU memory.

        This is called when max_model_len=-1 is used and the engine
        automatically determines the maximum context length that fits
        in GPU memory. Workers need to update their cached max_model_len
        to match the engine's decision.
        """
        self.model_config.max_model_len = max_model_len
        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)
        logger.debug("Updated max_model_len to %d", max_model_len)

462
    @instrument(span_name="Allocate KV cache")
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # With sleep mode enabled the KV cache is allocated in the "kv_cache"
        # memory pool (so sleep/wake_up can manage it); otherwise this context
        # is a no-op.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
481

482
    @instrument(span_name="Warmup (GPU)")
    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up the model, capture CUDA graphs and warm up sampling.

        Under VLLM_COMPILE, every size in ``compile_sizes`` that is not a
        cudagraph capture size is warmed up. So is the end of every compile
        range that contains no known size. Afterwards the kernels are warmed
        up, CUDA graphs are captured (unless ``enforce_eager``), a KV-cache
        sizing hint is logged, and the sampler/pooler (V1) or kernels (V2)
        are warmed up. The random seed is reset at the end.
        """
        compilation_config = self.vllm_config.compilation_config
        warmup_sizes = []

        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Sizes users still want compiled although they are not cudagraph
            # capture sizes, e.g. the max-num-batched-tokens size used by
            # chunked prefill.
            requested_sizes = compilation_config.compile_sizes
            warmup_sizes = [] if requested_sizes is None else requested_sizes.copy()
            cg_capture_sizes: list[int] = []

            if compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                capture_sizes = compilation_config.cudagraph_capture_sizes
                cg_capture_sizes = [] if capture_sizes is None else capture_sizes
                warmup_sizes = [
                    size for size in warmup_sizes if size not in cg_capture_sizes
                ]

            compile_ranges = compilation_config.get_compile_ranges()
            # A compile range with no known size inside it gets its end size
            # added, so that the range is still compiled/warmed up.
            known_sizes = set(cg_capture_sizes)
            known_sizes.update(size for size in warmup_sizes if isinstance(size, int))
            for compile_range in compile_ranges:
                if not any(size in compile_range for size in known_sizes):
                    warmup_sizes.append(compile_range.end)

        runner = self.model_runner

        # EPLB is skipped here so that no dummy metrics are recorded.
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        runner.maybe_remove_all_loras(runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = (
            0 if self.model_config.enforce_eager else runner.capture_model()
        )

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # The KV cache was sized by `memory_profiling`, which does not
            # account for CUDA graph memory and may leave GPU memory unused.
            # Suggest explicit kv cache sizes for users who want fine-grained
            # control. Profiling was empirically observed to slightly
            # underestimate usage, so keep a 150 MiB buffer against OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_bytes_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_bytes_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({format_gib(self.init_snapshot.free_memory)}/"
                f"{format_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{format_gib(self.requested_memory)} GiB). "
                f"Actual usage is {format_gib(runner.model_memory_usage)} "
                f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {format_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_bytes_requested_limit}` "
                f"({format_gib(kv_bytes_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_bytes_gpu_limit}` "
                f"({format_gib(kv_bytes_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
            )
            logger.debug(msg)

        if self.use_v2_model_runner:
            # V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
            warmup_kernels(runner)
        elif get_pp_group().is_last_rank:
            # V1: Warm up the sampler and preallocate buffers for logits and
            # other sampling tensors of the max possible shape, to avoid memory
            # fragmentation. This deliberately runs after `capture_model` so
            # that `torch.cuda.empty_cache` does not clear those buffers.
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # EPLB is skipped so that no dummy metrics are recorded.
            hidden_states, last_hidden_states = runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if runner.is_pooling_model:
                runner._dummy_pooler_run(hidden_states)
            else:
                runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed so the random state is not affected by model
        # initialization and profiling.
        set_random_seed(self.model_config.seed)

608
609
610
    def reset_mm_cache(self) -> None:
        """Ask the model runner to reset its multimodal cache."""
        runner = self.model_runner
        runner.reset_mm_cache()

611
612
613
    def reset_encoder_cache(self) -> None:
        """Ask the model runner to reset its encoder cache."""
        runner = self.model_runner
        runner.reset_encoder_cache()

614
615
616
    def get_model(self) -> nn.Module:
        """Return the model module held by the model runner."""
        runner = self.model_runner
        return runner.get_model()

617
618
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        runner = self.model_runner
        return runner.get_supported_tasks()
619

620
621
622
623
    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Get encoder timing stats from model runner."""
        runner = self.model_runner
        return runner.get_encoder_timing_stats()

624
625
    def annotate_profile(self, scheduler_output):
        """Return a trace-annotation context for the current iteration.

        The annotation encodes the number of context and generation requests
        (and their token counts), so iterations can be told apart in traces.
        It also advances the profiler by one step. A ``nullcontext`` is
        returned when no profiler is active.
        """
        # A context request is a request that has not yet generated any tokens.
        profiler = self.profiler
        if not profiler:
            return nullcontext()

        profiler.step()

        details = compute_iteration_details(scheduler_output)
        annotation = (
            f"execute_context_{details.num_ctx_requests}"
            f"({details.num_ctx_tokens})"
            f"_generation_{details.num_generation_requests}"
            f"({details.num_generation_tokens})"
        )
        return profiler.annotate_context_manager(annotation)
649

650
651
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Delegate token sampling for the current step to the model runner.

        Args:
            grammar_output: Forwarded unchanged to the model runner; may be
                ``None``.
        """
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)

656
657
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one model step, exchanging intermediate tensors across PP ranks.

        Flow:
          * Wait on any non-blocking pipeline-parallel (PP) sends issued by the
            previous call (``self._pp_send_work``).
          * On non-first PP ranks with tokens to process, post a receive
            (``irecv_tensor_dict``) and wrap the result in
            ``AsyncIntermediateTensors``.
          * Run the model runner under the profiler annotation.
          * If the output is ``IntermediateTensors`` (only allowed on non-last
            PP ranks), launch a non-blocking send to the next stage and keep its
            handles in ``self._pp_send_work``.

        Args:
            scheduler_output: Scheduler output describing this step's batch.

        Returns:
            The runner's output when it is a ``ModelRunnerOutput``,
            ``AsyncModelRunnerOutput`` or ``None``. Otherwise the intermediate
            tensors are sent downstream and ``None`` is returned.
        """
        # Ensure any previous non-blocking PP sends are complete.
        if self._pp_send_work:
            for handle in self._pp_send_work:
                handle.wait()
            self._pp_send_work = []

        intermediate_tensors = None
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0

        all_gather_tensors = {}
        compilation_config = self.vllm_config.compilation_config
        # Bound once and reused for the send-side sanity check at the end.
        parallel_config = self.vllm_config.parallel_config

        if (
            parallel_config.pipeline_parallel_size > 1
            and compilation_config.pass_config.enable_sp
            and forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            num_scheduled_tokens_np = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(num_scheduled_tokens_np),
                    num_scheduled_tokens_np=num_scheduled_tokens_np,
                    max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            # Per-tensor flags passed to the PP send/recv below (which also get
            # the TP group as `all_gather_group`). `residual` is flagged when
            # `is_residual_scattered_for_sp` reports it is NOT scattered for
            # this padded token count.
            all_gather_tensors = {
                "residual": not is_residual_scattered_for_sp(
                    self.vllm_config, batch_desc.num_tokens
                )
            }

        if forward_pass and not get_pp_group().is_first_rank:
            # Post the receive from the previous PP stage. The communication
            # handles and post-processing hooks travel with the tensors.
            tensor_dict, comm_handles, comm_postprocess = (
                get_pp_group().irecv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )
            assert tensor_dict is not None
            intermediate_tensors = AsyncIntermediateTensors(
                tensor_dict,
                comm_handles=comm_handles,
                comm_postprocess=comm_postprocess,
            )

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            # With the V2 runner, pooling models yield None here and produce
            # their result through a separate pool() call.
            if (
                self.use_v2_model_runner
                and self.model_runner.is_pooling_model
                and output is None
            ):
                output = self.model_runner.pool()  # type: ignore
            if isinstance(
                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
            ):
                return output

        # Only non-last PP ranks may reach this point: they forward their
        # IntermediateTensors to the next stage.
        assert isinstance(output, IntermediateTensors)
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        # Launch non-blocking send of intermediate tensors; completion is
        # awaited at the start of the next execute_model call.
        self._pp_send_work = get_pp_group().isend_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )
        return None

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Forward to the model runner's ``take_draft_token_ids``.

        Returns:
            The runner's result unchanged (``DraftTokenIds`` or ``None``).
        """
        runner = self.model_runner
        return runner.take_draft_token_ids()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        # Check if profiling is enabled
        if self.profiler_config is None or self.profiler_config.profiler is None:
753
754
755
756
757
758
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
759

760
        if is_start:
761
762
763
764
765
766
767
768
769
770
771
772
773
            # Generate the trace name by combining prefix with comprehensive rank suffix
            from vllm.distributed.utils import get_worker_rank_suffix

            rank_suffix = get_worker_rank_suffix(global_rank=self.rank)

            # Build the full trace name
            if profile_prefix:
                trace_name = f"{profile_prefix}_{rank_suffix}"
            else:
                trace_name = rank_suffix

            # Create the profiler wrapper only on the first start call
            if self.profiler is None:
774
775
                profiler_type = self.profiler_config.profiler
                if profiler_type == "torch":
776
777
778
779
780
781
782
783
784
                    self.profiler = TorchProfilerWrapper(
                        self.profiler_config,
                        worker_name=trace_name,
                        local_rank=self.local_rank,
                        activities=["CPU", "CUDA"],
                    )
                    logger.debug(
                        "Starting torch profiler with trace name: %s", trace_name
                    )
785
                elif profiler_type == "cuda":
786
787
                    self.profiler = CudaProfilerWrapper(self.profiler_config)
                    logger.debug("Starting CUDA profiler")
788
789
790
                else:
                    logger.warning("Unrecognized profiler: %s", profiler_type)
                    return
791
792
793
794
795
                self.profiler.start()
            else:
                # Profiler already initialized. Restart profiling but keep
                # the original trace name from the first initialization.
                self.profiler.start()
796
        else:
797
798
799
            if self.profiler is None:
                logger.warning("Profiler was not started, nothing to stop.")
                return
800
801
            self.profiler.stop()

802
    def execute_dummy_batch(self) -> None:
        """Run a one-token dummy pass through the model runner.

        Calls ``self.model_runner._dummy_run`` with a single token and
        ``uniform_decode=True`` and discards its return value.
        """
        num_dummy_tokens = 1
        self.model_runner._dummy_run(num_dummy_tokens, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model runner's ``add_lora``.

        Returns:
            The runner's boolean result.
        """
        runner = self.model_runner
        return runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model runner's ``remove_lora``.

        Returns:
            The runner's boolean result.
        """
        runner = self.model_runner
        return runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner's ``list_loras``."""
        runner = self.model_runner
        return runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model runner's ``pin_lora``.

        Returns:
            The runner's boolean result.
        """
        runner = self.model_runner
        return runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """Health probe for this worker.

        No check is performed: the worker is considered healthy as long as it
        is running, so this returns ``None`` and never raises.
        """
        return None

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model runner's model through ``ShardedStateLoader``.

        Args:
            path: Destination forwarded to ``ShardedStateLoader.save_model``.
            pattern: Optional pattern forwarded unchanged.
            max_size: Optional size limit forwarded unchanged.
        """
        # Imported lazily, only when saving is requested.
        from vllm.model_executor.model_loader import ShardedStateLoader

        model_to_save = self.model_runner.model
        ShardedStateLoader.save_model(
            model_to_save, path, pattern=pattern, max_size=max_size
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Delegate tensorizer serialization of the model to the model runner.

        Args:
            tensorizer_config: Passed through as the ``tensorizer_config``
                keyword of ``model_runner.save_tensorized_model``.
        """
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)

    def init_weight_transfer_engine(self, init_info: dict) -> None:
        """Initialize the weight transfer mechanism.

        The raw ``init_info`` dict is parsed by the configured engine into its
        backend-specific form and handed to ``init_transfer_engine``. For the
        NCCL backend, this creates a process group with the trainer.

        Args:
            init_info: Dictionary containing backend-specific initialization info.

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )
        engine.init_transfer_engine(engine.parse_init_info(init_info))

    def update_weights(self, update_info: dict) -> None:
        """Apply a batched weight update received from the trainer.

        The engine parses ``update_info`` into a backend-specific object.

        * If that object has ``is_checkpoint_format`` set, the weights are fed
          to the model's own ``load_weights``. This runs between
          ``initialize_layerwise_reload`` and ``finalize_layerwise_reload``,
          under ``torch.device(self.device)``.
        * Otherwise the weights are already in kernel format and are copied in
          place into the parameters with matching names.

        Args:
            update_info: Dictionary containing backend-specific update info.

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )

        # Parse dict into backend-specific typed dataclass
        typed_update_info = engine.parse_update_info(update_info)
        model = self.model_runner.model

        if typed_update_info.is_checkpoint_format:
            from vllm.model_executor.model_loader.reload import (
                finalize_layerwise_reload,
                initialize_layerwise_reload,
            )

            # Use layerwise reload pattern for checkpoint format weights.
            # torch.device(...) as a context manager makes it the default
            # device for tensors created inside the block.
            with torch.device(self.device):
                initialize_layerwise_reload(model)
                engine.receive_weights(
                    typed_update_info,
                    load_weights=model.load_weights,
                )
                finalize_layerwise_reload(model, self.model_config)
            return

        def load_weights_direct(
            weights: list[tuple[str, torch.Tensor]],
        ) -> None:
            # Weights are already in kernel format: copy them straight into
            # the existing parameters. The copy runs under no_grad because an
            # in-place write into a leaf parameter with requires_grad=True
            # raises a RuntimeError under autograd.
            with torch.no_grad():
                for name, weight in weights:
                    model.get_parameter(name).copy_(weight)

        engine.receive_weights(
            typed_update_info,
            load_weights=load_weights_direct,
        )

    def shutdown(self) -> None:
        """Release worker resources.

        Shuts down KV transfer, then the profiler (if one was created), then
        the weight transfer engine (if one is configured).
        """
        # ensure_kv_transfer_shutdown can be None during interpreter shutdown.
        if ensure_kv_transfer_shutdown is not None:
            ensure_kv_transfer_shutdown()

        # getattr (like the weight_transfer_engine lookup below) keeps shutdown
        # from raising AttributeError if the attribute was never set.
        profiler = getattr(self, "profiler", None)
        if profiler is not None:
            profiler.shutdown()

        if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
            weight_transfer_engine.shutdown()

    def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
        """Dispatch ``execute_method`` to the elastic-EP executor.

        All extra positional and keyword arguments are forwarded unchanged,
        and the executor's return value is passed back.
        """
        executor = self.elastic_ep_executor
        return executor.execute(execute_method, *args, **kwargs)


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Steps, in order:
      1. ``init_batch_invariance`` for the configured attention backend.
      2. ``override_envs_for_eplb`` with the parallel config.
      3. ``set_custom_all_reduce``, enabled unless the parallel config
         disables it.
      4. ``init_distributed_environment`` with the configured world size,
         using ``"env://"`` when no init method is given.
      5. ``ensure_model_parallel_initialized`` with the tensor-, pipeline-,
         prefill-context- and decode-context-parallel sizes.
      6. ``ensure_ec_transfer_initialized``.

    Args:
        vllm_config: Engine configuration supplying attention and parallel
            settings.
        rank: Rank of this process, passed to ``init_distributed_environment``.
        distributed_init_method: Init method; falsy values fall back to
            ``"env://"``.
        local_rank: Local rank forwarded to ``init_distributed_environment``.
        backend: Backend name forwarded to ``init_distributed_environment``.
    """
    attn_cfg = vllm_config.attention_config
    par_cfg = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance(attn_cfg.backend)
    override_envs_for_eplb(par_cfg)
    use_custom_all_reduce = not par_cfg.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    init_method = distributed_init_method if distributed_init_method else "env://"
    init_distributed_environment(
        par_cfg.world_size, rank, init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        par_cfg.tensor_parallel_size,
        par_cfg.pipeline_parallel_size,
        par_cfg.prefill_context_parallel_size,
        par_cfg.decode_context_parallel_size,
    )

    # Init the EC connector here, before KV caches are initialized.
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)