gpu_worker.py 39 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from collections.abc import Callable
8
from contextlib import AbstractContextManager, nullcontext
9
from types import NoneType
10
from typing import TYPE_CHECKING, Any
11

12
import numpy as np
13
import torch
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
18
from vllm.config.compilation import CompilationMode
19
20
21
22
23
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
24
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
25
from vllm.distributed.eplb.eplb_utils import override_envs_for_eplb
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
28
    ensure_kv_transfer_shutdown,
29
30
31
    get_kv_transfer_group,
    has_kv_transfer_group,
)
32
from vllm.distributed.parallel_state import (
33
    Handle,
34
35
36
    get_pp_group,
    get_tp_group,
)
37
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
38
from vllm.logger import init_logger
39
from vllm.lora.request import LoRARequest
40
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
41
from vllm.platforms import current_platform
42
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
43
from vllm.sequence import IntermediateTensors
44
from vllm.tasks import SupportedTask
45
from vllm.tracing import instrument
46
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
47
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
48
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
49
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
50
51
52
53
54
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
55
from vllm.v1.utils import compute_iteration_details, report_usage_stats
56
from vllm.v1.worker.utils import is_residual_scattered_for_sp
57
from vllm.v1.worker.worker_base import WorkerBase
58
from vllm.v1.worker.workspace import init_workspace_manager
59

60
from ...model_executor.model_loader import TensorizerLoader
61
from .gpu.warmup import warmup_kernels
62
63
from .utils import request_memory

64
65
66
logger = init_logger(__name__)

# Imported only for static type checking (used in annotations below).
if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


class AsyncIntermediateTensors(IntermediateTensors):
    """IntermediateTensors whose communication is synchronized lazily.

    Pending comm handles are waited on, and post-processing callbacks are run,
    the first time ``.tensors`` is read or ``wait_for_comm`` is called.
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        comm_handles: list[Handle] | None = None,
        comm_postprocess: list[Callable[[], None]] | None = None,
    ) -> None:
        super().__init__(tensors)
        self._comm_handles = comm_handles
        self._comm_postprocess = comm_postprocess
        self._comm_waited = False

    def wait_for_comm(self) -> None:
        """Wait on every pending handle, then run every post-process callback.

        Once this has completed, later calls return immediately.
        """
        if self._comm_waited:
            return
        for pending in self._comm_handles or ():
            pending.wait()
        for callback in self._comm_postprocess or ():
            callback()
        self._comm_waited = True

    def __getattribute__(self, name: str):
        # Reading `.tensors` must observe completed communication.
        raw_get = object.__getattribute__
        if name == "tensors" and not raw_get(self, "_comm_waited"):
            raw_get(self, "wait_for_comm")()
        return raw_get(self, name)


class Worker(WorkerBase):
104
105
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Set up per-worker state.

        The CUDA device and the model runner are created later, in
        ``init_device``.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # float32 matmul precision comes from the vLLM environment.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        # Deferred import (presumably to avoid an import cycle).
        from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor

        self.elastic_ep_executor = ElasticEPScalingExecutor(self)

        # CPU copies of model buffers taken by a level-2 sleep(); wake_up()
        # copies them back.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Weight transfer engine; only built when a weight_transfer_config exists.
        transfer_config = self.vllm_config.weight_transfer_config
        self.weight_transfer_engine = (
            None
            if transfer_config is None
            else WeightTransferEngineFactory.create_engine(
                transfer_config,
                self.vllm_config.parallel_config,
            )
        )

        # Torch/CUDA profiler, configured through profiler_config. The wrapper is
        # created lazily in profile() when start is called, so all information
        # needed for trace naming is available by then.
        self.profiler: Any | None = None
        self.profiler_config = vllm_config.profiler_config
        profiler_kind = self.profiler_config.profiler
        # Validate the choice now; instantiation is deferred.
        if profiler_kind not in ("torch", "cuda", None):
            raise ValueError(f"Unknown profiler type: {profiler_kind}")

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

        # Non-blocking PP send work left over from the previous iteration.
        self._pp_send_work: list[Handle] = []

    def sleep(self, level: int = 1) -> None:
        """Put this worker's GPU memory to sleep via the CuMem allocator.

        Level 1 passes ``("weights",)`` as the allocator's offload tags; any
        other level passes no offload tags. Level 2 also copies every model
        buffer to CPU so ``wake_up`` can restore it. Logs how much memory was
        freed and how much is still in use.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        if level == 2:
            named_buffers = self.model_runner.model.named_buffers()
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone() for buf_name, buf in named_buffers
            }

        offload_tags = ("weights",) if level == 1 else tuple()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        released = free_after - free_before
        still_used = total - free_after
        assert released >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(released),
            format_gib(still_used),
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up memory released by ``sleep``.

        Args:
            tags: Allocator tags forwarded to ``CuMemAllocator.wake_up``. ``None``
                is treated here as covering every tag.

        Buffers saved by a level-2 sleep are copied back only once the
        "weights" tag has been woken. If the FP8 KV cache was woken, its scales
        are re-initialized.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        def _was_woken(tag: str) -> bool:
            return tags is None or tag in tags

        # Restore the buffers after level 2 sleep. Model buffers are allocated
        # while the "weights" memory pool is active (load_model runs inside
        # _maybe_get_memory_pool_context(tag="weights")), so copying into them
        # before that pool is woken would write to memory that is still asleep.
        # Keep the saved copies until a wake_up that includes "weights".
        if self._sleep_saved_buffers and _was_woken("weights"):
            saved = self._sleep_saved_buffers
            for name, buffer in self.model_runner.model.named_buffers():
                if name in saved:
                    buffer.data.copy_(saved[name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up, the internal state of the
        # cache engine must be reset, especially the FP8 scaling factor.
        if (
            _was_woken("kv_cache")
            and self.cache_config.cache_dtype.startswith("fp8")
            and hasattr(self.model_runner, "init_fp8_kv_scales")
        ):
            self.model_runner.init_fp8_kv_scales()

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context manager routing allocations to the pool for ``tag``.

        With sleep mode disabled this is a no-op ``nullcontext``. Otherwise it is
        the CuMem allocator's memory pool for ``tag``. For the "weights" tag the
        allocator must have no current usage, because sleep mode supports only
        one instance per process.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU and CPU KV-cache block counts on the cache config."""
        cfg = self.cache_config
        cfg.num_gpu_blocks = num_gpu_blocks
        cfg.num_cpu_blocks = num_cpu_blocks

    @instrument(span_name="Init device")
    def init_device(self):
        """Bind this worker to its CUDA device and build the model runner.

        In order:
        - optionally offset ``local_rank`` by the DP rank;
        - select and validate the device;
        - initialize the distributed environment (before the memory snapshot,
          so NCCL buffers already exist when free memory is measured);
        - seed RNGs and take the initial memory snapshot;
        - set up the workspace manager;
        - construct the V1 or V2 GPU model runner.

        Raises:
            RuntimeError: if the configured device type is not ``"cuda"``.
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        pc = self.parallel_config
        offset_local_rank = (
            pc.distributed_executor_backend not in ("ray", "external_launcher")
            and pc.data_parallel_backend != "ray"
            and pc.nnodes_within_dp == 1
        )
        if offset_local_rank:
            # Prefer the local DP rank; fall back to the global DP index.
            dp_rank = pc.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = pc.data_parallel_index

            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            tp_pp_size = pc.pipeline_parallel_size * pc.tensor_parallel_size
            self.local_rank += dp_rank * tp_pp_size
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            num_visible = (
                torch.cuda.device_count() if torch.cuda.is_available() else 0
            )
            assert pc.local_world_size <= num_visible, (
                f"local_world_size ({pc.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({num_visible})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # snapshot, so NCCL buffers are already allocated when available
        # memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        if self.use_v2_model_runner:
            logger.info_once("Using V2 Model Runner", scope="local")

        set_random_seed(self.model_config.seed)

        # Snapshot memory now that NCCL is initialized.
        gc.collect()
        torch.cuda.empty_cache()
        snapshot = MemorySnapshot(device=self.device)
        self.init_snapshot = snapshot
        self.requested_memory = request_memory(snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Workspace manager: two micro-batches when DBO is enabled, else one.
        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, num_ubatches)

        # Construct the model runner (runner modules are imported on demand).
        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model inside the "weights" memory-pool context.

        When env var ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` is ``"1"`` (an elastic-EP
        scale-up launch, as the name suggests):
        - the expert mapping is first received from the elastic EP executor and
          ``num_redundant_experts`` is derived from it;
        - the model is loaded with dummy weights;
        - EPLB is then set up from the mapping and ``eep_eplb_suppressed`` is set.
        """
        scale_up_launch = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"

        if scale_up_launch:
            expert_map, num_logical, old_num_physical = (
                self.elastic_ep_executor.receive_expert_mapping()
            )
            num_physical = expert_map.shape[1]
            self.parallel_config.eplb_config.num_redundant_experts = (
                num_physical - num_logical
            )

        with (
            self._maybe_get_memory_pool_context(tag="weights"),
            set_current_vllm_config(self.vllm_config),
        ):
            self.model_runner.load_model(load_dummy_weights=scale_up_launch)

        if scale_up_launch:
            self.model_runner.setup_eplb_from_mapping(expert_map, old_num_physical)
            self.model_runner.eep_eplb_suppressed = True

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

    def reload_weights(self, *args, **kwargs) -> None:
        """Delegate weight reloading to the model runner, passing args through."""
        runner = self.model_runner
        runner.reload_weights(*args, **kwargs)

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile peak memory usage to size the KV cache without OOMs.

        A dummy forward pass (``profile_run``) is profiled against the initial
        memory snapshot. The bytes available for the KV cache are the requested
        memory minus what the profile reports as non-KV-cache memory.

        If ``cache_config.kv_cache_memory_bytes`` is set (truthy), profiling is
        skipped and that value is returned. A ``profile_run`` still happens so the
        model gets compiled for ``max_num_batched_tokens``.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            Number of bytes usable for the KV cache.
        """
        fixed_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if fixed_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            message = (
                f"Initial free memory {format_gib(self.init_snapshot.free_memory)} "
                f"GiB, reserved {format_gib(fixed_kv_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(message)
            return fixed_kv_bytes

        # Profile the memory usage of a forward pass with dummy inputs.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as result:
            self.model_runner.profile_run()

        self.non_torch_memory = result.non_torch_increase
        self.peak_activation_memory = result.torch_peak_increase

        free_after_profile = result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory >= free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {format_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - result.non_kv_cache_memory
        )

        # Free memory at startup that lies outside the requested budget.
        outside_request = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            format_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_after_profile),
            format_gib(free_after_profile - outside_request),
        )
        logger.debug(result)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Get KV connector handshake metadata from this worker, keyed by TP rank.

        Returns ``None`` when there is no KV transfer group, or when the
        connector has no handshake metadata to exchange across workers.
        """
        if has_kv_transfer_group():
            metadata = get_kv_transfer_group().get_handshake_metadata()
            if metadata is not None:
                return {get_tp_group().rank_in_group: metadata}
        return None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

    def update_max_model_len(self, max_model_len: int) -> None:
        """Update max_model_len after auto-fit to GPU memory.

        Used when ``max_model_len=-1`` lets the engine pick the largest context
        length that fits in GPU memory. The worker's cached value, and the model
        runner's when a runner exists, must follow the engine's decision.
        """
        self.model_config.max_model_len = max_model_len
        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)
        logger.debug("Updated max_model_len to %d", max_model_len)

    @instrument(span_name="Allocate KV cache")
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        With sleep mode enabled, the allocation happens inside the CuMem
        allocator's "kv_cache" memory pool. Otherwise no pool is used.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Reuse the shared sleep-mode helper instead of duplicating its logic.
        # For the "kv_cache" tag it returns the allocator's pool when sleep mode
        # is enabled, and a nullcontext otherwise.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    @instrument(span_name="Warmup (GPU)")
    def compile_or_warm_up_model(self) -> float:
        """Compile and warm up the model, capture CUDA graphs, warm up sampling.

        Steps:
        - run dummy passes for extra compile sizes (VLLM_COMPILE mode only);
        - warm up kernels;
        - capture CUDA graphs unless ``enforce_eager``;
        - log a ``--kv-cache-memory`` suggestion when the KV cache size came
          from profiling;
        - warm up the V2 kernels, or the V1 sampler/pooler on the last PP rank.

        Returns:
            ``self.compilation_config.compilation_time``.
        """
        comp_cfg = self.vllm_config.compilation_config
        warmup_sizes = []

        if comp_cfg.mode == CompilationMode.VLLM_COMPILE:
            # warm up sizes that are not in cudagraph capture sizes,
            # but users still want to compile for better performance,
            # e.g. for the max-num-batched token size in chunked prefill.
            extra_sizes = comp_cfg.compile_sizes
            warmup_sizes = [] if extra_sizes is None else extra_sizes.copy()
            capture_sizes: list[int] = []

            if comp_cfg.cudagraph_mode != CUDAGraphMode.NONE:
                configured = comp_cfg.cudagraph_capture_sizes
                capture_sizes = [] if configured is None else configured
                warmup_sizes = [s for s in warmup_sizes if s not in capture_sizes]

            ranges = comp_cfg.get_compile_ranges()
            # For each compile range that contains none of the warmup or
            # cudagraph capture sizes, add the range's end so it still gets
            # compiled/warmed up.
            covered = set(capture_sizes)
            covered.update([s for s in warmup_sizes if isinstance(s, int)])
            for compile_range in ranges:
                if not any(s in compile_range for s in covered):
                    warmup_sizes.append(compile_range.end)

        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = (
            0 if self.model_config.enforce_eager else self.model_runner.capture_model()
        )

        used_profiling = self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        )
        if used_profiling:
            # Suggest an explicit KV cache size. memory_profiling provides
            # peak_activation_memory and other consumption figures, but it does
            # not account for CUDA graph memory and may not use all GPU memory;
            # users may want fine-grained control over the KV cache size.

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            suggestion = (
                f"Free memory on device "
                f"({format_gib(self.init_snapshot.free_memory)}/"
                f"{format_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{format_gib(self.requested_memory)} GiB). "
                f"Actual usage is {format_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {format_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{to_requested_limit}` "
                f"({format_gib(to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{to_gpu_limit}` "
                f"({format_gib(to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
            )
            logger.debug(suggestion)

        if self.use_v2_model_runner:
            # V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
            warmup_kernels(self.model_runner)
        elif get_pp_group().is_last_rank:
            # V1: Warm up sampler and preallocate memory buffer for logits and other
            # sampling related tensors of max possible shape to avoid memory
            # fragmentation issue.
            # NOTE: This is called after `capture_model` on purpose to prevent
            # memory buffers from being cleared by `torch.cuda.empty_cache`.
            sched_cfg = self.scheduler_config
            max_num_reqs = min(
                sched_cfg.max_num_seqs,
                sched_cfg.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

        return self.compilation_config.compilation_time

    def reset_mm_cache(self) -> None:
        """Ask the model runner to reset its multimodal cache."""
        runner = self.model_runner
        runner.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        """Ask the model runner to reset its encoder cache."""
        runner = self.model_runner
        runner.reset_encoder_cache()

    def get_model(self) -> nn.Module:
        """Return the model module held by the model runner."""
        runner = self.model_runner
        return runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        runner = self.model_runner
        return runner.get_supported_tasks()

    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Return encoder timing statistics collected by the model runner."""
        runner = self.model_runner
        return runner.get_encoder_timing_stats()

    def annotate_profile(self, scheduler_output):
        """Return a profiler annotation context for the current iteration.

        Without an active profiler this is a ``nullcontext``. Otherwise the
        profiler is stepped, and the returned context is labelled with the
        context/generation request and token counts, so iterations are easy to
        tell apart in traces. A context request is a request that has not yet
        generated any tokens.
        """
        if not self.profiler:
            return nullcontext()

        self.profiler.step()

        details = compute_iteration_details(scheduler_output)
        label = (
            f"execute_context_{details.num_ctx_requests}"
            f"({details.num_ctx_tokens})"
            f"_generation_{details.num_generation_requests}"
            f"({details.num_generation_tokens})"
        )
        return self.profiler.annotate_context_manager(label)

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Sample tokens through the model runner, under inference mode.

        ``grammar_output`` (possibly ``None``) is forwarded unchanged.
        """
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)

    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Execute one scheduled step on this worker's pipeline stage.

        On pipeline-parallel (PP) ranks other than the first, an asynchronous
        receive of the previous stage's intermediate tensors is posted before
        the model runner is invoked. If the runner yields a final output (or
        ``None``), it is returned directly. Otherwise the produced
        ``IntermediateTensors`` are sent to the next PP stage with a
        non-blocking send and ``None`` is returned. Those send handles are
        waited on at the start of the next call.

        Args:
            scheduler_output: The scheduler's plan for this step.

        Returns:
            The model runner output on ranks that produce one, else ``None``.
        """
        # ensure any previous non-blocking PP sends are complete
        if self._pp_send_work:
            for handle in self._pp_send_work:
                handle.wait()
            self._pp_send_work = []

        intermediate_tensors = None
        # Read the token count once; with zero tokens no PP communication is
        # performed for this step.
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0

        # Tensor-name -> flag mapping forwarded as ``all_gather_tensors`` to
        # both the PP receive and the PP send below. It is only populated when
        # pipeline parallelism is combined with sequence parallelism.
        all_gather_tensors = {}
        compilation_config = self.vllm_config.compilation_config
        parallel_config = self.vllm_config.parallel_config

        if (
            parallel_config.pipeline_parallel_size > 1
            and compilation_config.pass_config.enable_sp
            and forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            num_scheduled_tokens_np = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(num_scheduled_tokens_np),
                    num_scheduled_tokens_np=num_scheduled_tokens_np,
                    max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            # Flag the residual for all-gather only when it is not scattered
            # for SP at the (padded) batch token count.
            all_gather_tensors = {
                "residual": not is_residual_scattered_for_sp(
                    self.vllm_config, batch_desc.num_tokens
                )
            }

        if forward_pass and not get_pp_group().is_first_rank:
            # Post an asynchronous receive from the previous PP stage; the
            # communication handles and post-processing hooks travel with the
            # AsyncIntermediateTensors handed to the model runner.
            tensor_dict, comm_handles, comm_postprocess = (
                get_pp_group().irecv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )
            assert tensor_dict is not None
            intermediate_tensors = AsyncIntermediateTensors(
                tensor_dict,
                comm_handles=comm_handles,
                comm_postprocess=comm_postprocess,
            )

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            # With the V2 runner, a pooling model's execute_model can return
            # None; pooling is then invoked explicitly.
            if (
                self.use_v2_model_runner
                and self.model_runner.is_pooling_model
                and output is None
            ):
                output = self.model_runner.pool()  # type: ignore
            if isinstance(
                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
            ):
                return output

        # Past this point the runner produced intermediate tensors, which only
        # a non-last PP rank may do (enforced by the asserts below).
        # ``parallel_config`` was already fetched above; no need to re-read it.
        assert isinstance(output, IntermediateTensors)
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        # launch non-blocking send of intermediate tensors
        self._pp_send_work = get_pp_group().isend_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        return None
749

750
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Fetch draft token ids from the model runner.

        Returns:
            The result of ``model_runner.take_draft_token_ids``; ``None`` when
            the runner has none to hand over.
        """
        runner = self.model_runner
        return runner.take_draft_token_ids()

753
754
755
    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        # Check if profiling is enabled
        if self.profiler_config is None or self.profiler_config.profiler is None:
756
757
758
759
760
761
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
762

763
        if is_start:
764
765
766
767
768
769
770
771
772
773
774
775
776
            # Generate the trace name by combining prefix with comprehensive rank suffix
            from vllm.distributed.utils import get_worker_rank_suffix

            rank_suffix = get_worker_rank_suffix(global_rank=self.rank)

            # Build the full trace name
            if profile_prefix:
                trace_name = f"{profile_prefix}_{rank_suffix}"
            else:
                trace_name = rank_suffix

            # Create the profiler wrapper only on the first start call
            if self.profiler is None:
777
778
                profiler_type = self.profiler_config.profiler
                if profiler_type == "torch":
779
780
781
782
783
784
785
786
787
                    self.profiler = TorchProfilerWrapper(
                        self.profiler_config,
                        worker_name=trace_name,
                        local_rank=self.local_rank,
                        activities=["CPU", "CUDA"],
                    )
                    logger.debug(
                        "Starting torch profiler with trace name: %s", trace_name
                    )
788
                elif profiler_type == "cuda":
789
790
                    self.profiler = CudaProfilerWrapper(self.profiler_config)
                    logger.debug("Starting CUDA profiler")
791
                else:
792
793
794
795
796
797
798
799
                    # Config validation should prevent this code being reached
                    raise ValueError(
                        f"Invalid profiler value of {self.profiler_config.profiler}"
                    )

            # If profiler already initialized, restart profiling but keep
            # the original trace name from the first initialization.
            self.profiler.start()
800
        else:
801
802
803
            if self.profiler is None:
                logger.warning("Profiler was not started, nothing to stop.")
                return
804
805
            self.profiler.stop()

806
    def execute_dummy_batch(self) -> None:
        """Run a dummy batch through the model runner.

        Invokes ``model_runner._dummy_run`` with a size argument of ``1`` and
        ``uniform_decode=True``.
        """
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)
808

809
810
811
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model runner.

        Returns:
            The model runner's ``add_lora`` result.
        """
        runner = self.model_runner
        return runner.add_lora(lora_request)

812
813
814
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter identified by ``lora_id``.

        Returns:
            The model runner's ``remove_lora`` result.
        """
        runner = self.model_runner
        return runner.remove_lora(lora_id)

815
    def list_loras(self) -> set[int]:
        """Return the LoRA ids known to the model runner."""
        runner = self.model_runner
        return runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter ``lora_id`` via the model runner.

        Returns:
            The model runner's ``pin_lora`` result.
        """
        runner = self.model_runner
        return runner.pin_lora(lora_id)

821
822
823
824
    def check_health(self) -> None:
        """Health probe for this worker; there is nothing to check.

        The worker is considered healthy for as long as it is running.
        """
        return None

825
826
827
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model runner's model via ``ShardedStateLoader``.

        Args:
            path: Destination passed to ``ShardedStateLoader.save_model``.
            pattern: Forwarded unchanged as ``pattern``.
            max_size: Forwarded unchanged as ``max_size``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(
            model, path, pattern=pattern, max_size=max_size
        )

840
841
842
    def save_tensorized_model(self, tensorizer_config: "TensorizerConfig") -> None:
        """Serialize the current model with ``TensorizerLoader``.

        Args:
            tensorizer_config: Tensorizer settings, forwarded together with
                this worker's ``model_config``.
        """
        loader_kwargs = {
            "tensorizer_config": tensorizer_config,
            "model_config": self.model_config,
        }
        TensorizerLoader.save_model(self.get_model(), **loader_kwargs)
846

847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
    def init_weight_transfer_engine(self, init_info: dict) -> None:
        """Initialize the weight transfer mechanism.

        The engine first parses the raw dict into its backend-specific typed
        form, then receives it via ``init_transfer_engine``. For the NCCL
        backend, this creates a process group with the trainer.

        Args:
            init_info: Dictionary containing backend-specific initialization info

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )
        # Parse dict into backend-specific typed dataclass, then initialize.
        engine.init_transfer_engine(engine.parse_init_info(init_info))

    def update_weights(self, update_info: dict) -> None:
        """Apply a batched weight update received from the trainer.

        The weight transfer engine parses ``update_info``. Checkpoint-format
        updates go through the layerwise reload path, using the model's
        ``load_weights``. Kernel-format updates are copied directly into the
        model parameters with matching names.

        Args:
            update_info: Dictionary containing backend-specific update info

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        if self.weight_transfer_engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )

        # Parse dict into backend-specific typed dataclass
        typed_update_info = self.weight_transfer_engine.parse_update_info(update_info)

        model = self.model_runner.model

        if typed_update_info.is_checkpoint_format:
            from vllm.model_executor.model_loader.reload import (
                finalize_layerwise_reload,
                initialize_layerwise_reload,
            )

            # Use layerwise reload pattern for checkpoint format weights
            with torch.device(self.device):
                initialize_layerwise_reload(model)
                self.weight_transfer_engine.receive_weights(
                    typed_update_info,
                    load_weights=model.load_weights,
                )
                finalize_layerwise_reload(model, self.model_config)
        else:
            # Weights are already in kernel format, copy directly
            def load_weights_direct(
                weights: list[tuple[str, torch.Tensor]],
            ) -> None:
                # In-place copies into parameters must bypass autograd:
                # ``copy_`` on a leaf tensor with requires_grad=True raises
                # outside of no_grad/inference mode, and this method is not
                # itself decorated with either.
                with torch.no_grad():
                    for name, weight in weights:
                        param = model.get_parameter(name)
                        param.copy_(weight)

            self.weight_transfer_engine.receive_weights(
                typed_update_info,
                load_weights=load_weights_direct,
            )

910
    def shutdown(self) -> None:
        """Release resources held by this worker.

        Shuts down KV transfer, then the profiler (if one was created), then
        the weight transfer engine (if one is set).
        """
        # ensure_kv_transfer_shutdown can be None during interpreter shutdown.
        if ensure_kv_transfer_shutdown is not None:
            ensure_kv_transfer_shutdown()

        # Read attributes defensively so shutdown also works when the
        # attribute was never assigned, as is already done for
        # weight_transfer_engine.
        profiler = getattr(self, "profiler", None)
        if profiler is not None:
            profiler.shutdown()

        if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
            weight_transfer_engine.shutdown()

920
921
922
    def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
        """Forward a call to ``self.elastic_ep_executor.execute``.

        ``execute_method`` and any extra positional or keyword arguments are
        passed through unchanged, and the executor's result is returned.
        """
        executor = self.elastic_ep_executor
        return executor.execute(execute_method, *args, **kwargs)

923
924

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Steps, in order:

    1. Initialize batch invariance for the configured attention backend.
    2. Apply EPLB environment overrides.
    3. Toggle custom all-reduce.
    4. Initialize the process group.
    5. Initialize the model-parallel groups (tensor, pipeline, prefill
       context, decode context).
    6. Initialize the EC transfer connector.

    Args:
        vllm_config: Engine configuration; its attention and parallel
            sub-configs drive the setup.
        rank: Rank passed to ``init_distributed_environment`` alongside
            ``parallel_config.world_size``.
        distributed_init_method: Init method for the process group; defaults
            to ``"env://"`` when not given.
        local_rank: Local rank passed to ``init_distributed_environment``.
        backend: Process-group backend name.
    """
    attention_config = vllm_config.attention_config
    parallel_config = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance(attention_config.backend)
    override_envs_for_eplb(parallel_config)
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_method = distributed_init_method or "env://"
    init_distributed_environment(
        parallel_config.world_size, rank, init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.prefill_context_parallel_size,
        parallel_config.decode_context_parallel_size,
    )

    # Init the EC connector here, before KV cache initialization.
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)