"tests/kernels/untest_ggml.py" did not exist on "300da09177477d0a4d2b55790addefd971f52ae0"
gpu_worker.py 47 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from collections.abc import Callable
8
from contextlib import AbstractContextManager, nullcontext
9
from types import NoneType
10
from typing import TYPE_CHECKING, Any, cast
11

12
import numpy as np
13
14
import torch
import torch.distributed
15
import torch.nn as nn
16

17
import vllm.envs as envs
18
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
19
from vllm.config.compilation import CompilationMode
20
21
22
23
24
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
25
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
26
from vllm.distributed.eplb.eplb_utils import override_envs_for_eplb
27
28
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
29
    ensure_kv_transfer_shutdown,
30
31
32
    get_kv_transfer_group,
    has_kv_transfer_group,
)
33
from vllm.distributed.parallel_state import (
34
    Handle,
35
    get_pcp_group,
36
37
38
    get_pp_group,
    get_tp_group,
)
39
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
40
from vllm.logger import init_logger
41
from vllm.lora.request import LoRARequest
42
from vllm.model_executor.models.interfaces import is_mixture_of_experts
43
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
44
from vllm.platforms import current_platform
45
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
46
from vllm.sequence import IntermediateTensors
47
from vllm.tasks import SupportedTask
48
from vllm.tracing import instrument
49
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
50
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
51
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
52
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
53
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
54
55
56
57
58
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
59
from vllm.v1.utils import compute_iteration_details, report_usage_stats
60
from vllm.v1.worker.utils import is_residual_scattered_for_sp
61
from vllm.v1.worker.worker_base import WorkerBase
62
from vllm.v1.worker.workspace import init_workspace_manager
63

64
from .gpu.warmup import warmup_kernels
65
66
from .utils import request_memory

67
68
69
logger = init_logger(__name__)

if TYPE_CHECKING:
70
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
71
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
72
73


74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
class AsyncIntermediateTensors(IntermediateTensors):
    """IntermediateTensors whose ``tensors`` may still be in flight.

    Holds optional communication handles and post-processing callbacks.
    Nothing is waited on at construction time. The first read of
    ``.tensors`` (or an explicit :meth:`wait_for_comm` call) blocks on every
    handle and then runs every callback. After that has completed, further
    reads do not wait again.
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        comm_handles: list[Handle] | None = None,
        comm_postprocess: list[Callable[[], None]] | None = None,
    ) -> None:
        super().__init__(tensors)
        self._comm_handles = comm_handles
        self._comm_postprocess = comm_postprocess
        # Becomes True only after all handles waited and all callbacks ran.
        self._comm_waited = False

    def wait_for_comm(self) -> None:
        """Wait on pending handles, then run post-processing callbacks.

        A no-op once a previous call has completed successfully.
        """
        if not self._comm_waited:
            for pending in self._comm_handles or ():
                pending.wait()
            for callback in self._comm_postprocess or ():
                callback()
            self._comm_waited = True

    def __getattribute__(self, name: str):
        # Any read of `.tensors` first forces pending communication to finish.
        # object.__getattribute__ is used directly so this hook never recurses.
        fetch = object.__getattribute__
        if name == "tensors" and not fetch(self, "_comm_waited"):
            fetch(self, "wait_for_comm")()
        return fetch(self, name)


106
class Worker(WorkerBase):
107
108
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Set up worker state that does not need a device yet.

        Device selection, distributed init and model-runner construction are
        deferred to :meth:`init_device`.

        Raises:
            ValueError: if ``profiler_config.profiler`` is not ``"torch"``,
                ``"cuda"`` or ``None``.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Apply the float32 matmul precision chosen through the vLLM env var.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        # Model buffers copied to CPU by a level-2 sleep(); wake_up() restores them.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Created eagerly here, but only when a weight-transfer config is set.
        transfer_cfg = self.vllm_config.weight_transfer_config
        self.weight_transfer_engine = (
            None
            if transfer_cfg is None
            else WeightTransferEngineFactory.create_engine(
                transfer_cfg,
                self.vllm_config.parallel_config,
            )
        )

        # The Torch/CUDA profiler wrapper is built lazily by profile() when
        # profiling starts, so trace naming has all the information it needs.
        # Only the configured profiler type is validated at this point.
        self.profiler: Any | None = None
        self.profiler_config = vllm_config.profiler_config
        profiler_kind = self.profiler_config.profiler
        if profiler_kind not in ("torch", "cuda", None):
            raise ValueError(f"Unknown profiler type: {profiler_kind}")

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

        # Handles of non-blocking PP sends from the previous iteration;
        # execute_model() waits on them before starting the next step.
        self._pp_send_work: list[Handle] = []
Woosuk Kwon's avatar
Woosuk Kwon committed
153

154
    def sleep(self, level: int = 1) -> None:
        """Release GPU memory through the CuMem allocator.

        Args:
            level: with ``1`` the allocator is asked to offload the
                ``"weights"`` tag; any other level passes an empty tag tuple.
                With ``2`` the model's named buffers are first copied to CPU
                so :meth:`wake_up` can restore them.

        Raises:
            AssertionError: (via ``assert``) if free device memory went down
                across the sleep.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        if level == 2:
            # Keep CPU copies of the buffers; wake_up() writes them back.
            buffer_items = self.model_runner.model.named_buffers()
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone() for buf_name, buf in buffer_items
            }

        offload_tags = ("weights",) if level == 1 else tuple()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed = free_after - free_before
        still_used = total - free_after
        assert freed >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(freed),
            format_gib(still_used),
        )
177

178
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Re-materialize memory released by :meth:`sleep`.

        Args:
            tags: allocator tags to wake. The value is forwarded unchanged to
                the allocator; ``None`` also counts as including
                ``"kv_cache"`` for the FP8 scale reset below.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        saved = self._sleep_saved_buffers
        if saved:
            # A level-2 sleep left CPU copies behind; copy them back in place.
            for buf_name, buf in self.model_runner.model.named_buffers():
                if buf_name in saved:
                    buf.data.copy_(saved[buf_name].data)
            self._sleep_saved_buffers = {}

        kv_cache_woken = tags is None or "kv_cache" in tags
        if not kv_cache_woken:
            return
        # A freshly woken KV cache needs its cache-engine state reset,
        # in particular the FP8 scaling factors.
        if self.cache_config.cache_dtype.startswith("fp8") and hasattr(
            self.model_runner, "init_fp8_kv_scales"
        ):
            self.model_runner.init_fp8_kv_scales()

202
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return the CuMem pool context for ``tag`` if sleep mode is enabled.

        Without sleep mode a no-op ``nullcontext()`` is returned.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        # The allocator must be empty when weights are loaded: only one
        # sleep-mode instance is allowed per process.
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)
214

215
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU and CPU KV-cache block counts on ``cache_config``."""
        cfg = self.cache_config
        cfg.num_gpu_blocks = num_gpu_blocks
        cfg.num_cpu_blocks = num_cpu_blocks

219
    @instrument(span_name="Init device")
    def init_device(self):
        """Bind this worker to a CUDA device and build its model runner.

        In order:
          1. pick the device index, offset by the DP rank when this process
             manages DP placement itself;
          2. initialize the distributed environment;
          3. reseed;
          4. take the baseline memory snapshot and compute the requested
             memory;
          5. set up the workspace manager;
          6. construct the V1 or V2 GPU model runner.

        Raises:
            RuntimeError: if the configured device type is not ``"cuda"``.
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # Ray sets this env var and it causes exceptions during graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        pc = self.parallel_config
        needs_dp_offset = (
            pc.distributed_executor_backend not in ("ray", "external_launcher")
            and pc.data_parallel_backend != "ray"
            and pc.nnodes_within_dp == 1
        )
        if needs_dp_offset:
            # Prefer the local DP rank; fall back to the global DP index.
            dp_rank = pc.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = pc.data_parallel_index

            # device index = DP_LOCAL_RANK * (PP * TP) + TP_LOCAL_RANK
            ranks_per_dp = pc.pipeline_parallel_size * pc.tensor_parallel_size
            self.local_rank += dp_rank * ranks_per_dp
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            n_visible = torch.cuda.device_count() if torch.cuda.is_available() else 0
            assert pc.local_world_size <= n_visible, (
                f"local_world_size ({pc.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({n_visible})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Bring up the distributed environment before the memory snapshot so
        # NCCL buffers are already allocated when free memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        if self.use_v2_model_runner:
            logger.info_once("Using V2 Model Runner", scope="local")

        set_random_seed(self.model_config.seed)

        # Baseline snapshot, taken after collecting garbage and emptying the
        # CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()
        snapshot = MemorySnapshot(device=self.device)
        self.init_snapshot = snapshot
        self.requested_memory = request_memory(snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Two micro-batches when enable_dbo is set, otherwise one.
        ubatch_count = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, ubatch_count)

        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        # Usage stats (if enabled) are reported from rank 0 only.
        if self.rank == 0:
            report_usage_stats(self.vllm_config)

317
318
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model weights through the model runner.

        In sleep mode the load runs inside the "weights" memory pool. The
        elastic-EP scale-up flag is read from
        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` and forwarded to the runner.
        """
        is_scale_up_launch = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        pool_ctx = self._maybe_get_memory_pool_context(tag="weights")
        with pool_ctx, set_current_vllm_config(self.vllm_config):
            self.model_runner.load_model(eep_scale_up=is_scale_up_launch)
326

327
328
329
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

330
331
    def reload_weights(self, *args, **kwargs) -> None:
        """Delegate weight reloading to the model runner, passing arguments through."""
        runner = self.model_runner
        runner.reload_weights(*args, **kwargs)
332

333
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return how many bytes of GPU memory the KV cache may use.

        There are two paths:

        * ``cache_config.kv_cache_memory_bytes`` is set (truthy). A profile
          run still executes, because it compiles the model for
          ``max_num_batched_tokens``. No measurement is taken, and the
          configured value is returned unchanged.
        * Otherwise a dummy forward pass runs under ``memory_profiling``.
          The result is ``requested_memory`` (computed in ``init_device``)
          minus the profiler's ``non_kv_cache_memory``. The profiler's
          non-torch increase and torch peak increase are kept on ``self`` as
          ``non_torch_memory`` and ``peak_activation_memory``.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Raises:
            AssertionError: (via ``assert``) if free memory grew during
                profiling.
        """
        manual_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if manual_kv_bytes:
            # Profiling is skipped, but the run is still needed so the model
            # gets compiled for max_num_batched_tokens.
            self.model_runner.profile_run()
            msg = (
                f"Initial free memory {format_gib(self.init_snapshot.free_memory)} "
                f"GiB, reserved {format_gib(manual_kv_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return manual_kv_bytes

        # Measure a dummy forward pass against the init_device snapshot.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as result:
            self.model_runner.profile_run()

        self.non_torch_memory = result.non_torch_increase
        self.peak_activation_memory = result.torch_peak_increase

        initial_free = self.init_snapshot.free_memory
        free_now = result.after_profile.free_memory
        # NOTE(woosuk): other processes sharing this GPU are assumed not to
        # change their memory usage while profiling runs.
        assert initial_free >= free_now, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(initial_free)} GiB, "
            f"current free memory {format_gib(free_now)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - result.non_kv_cache_memory
        )

        # Portion of the initially free memory that vLLM did not request.
        not_requested = initial_free - self.requested_memory
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(initial_free),
            self.cache_config.gpu_memory_utilization,
            format_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_now),
            format_gib(free_now - not_requested),
        )
        logger.debug(result)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        return int(self.available_kv_cache_memory_bytes)
413

414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return this worker's KV-connector handshake metadata keyed by TP rank.

        Returns ``None`` in two cases: no KV transfer group exists, or the
        connector has no handshake metadata to exchange across workers.
        """
        if has_kv_transfer_group():
            metadata = get_kv_transfer_group().get_handshake_metadata()
            if metadata is not None:
                return {get_tp_group().rank_in_group: metadata}
        return None

429
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec mapping reported by the model runner."""
        spec = self.model_runner.get_kv_cache_spec()
        return spec

432
433
434
435
436
437
438
439
440
441
    def update_max_model_len(self, max_model_len: int) -> None:
        """Adopt an engine-chosen ``max_model_len`` after auto-fit to GPU memory.

        Called when ``max_model_len=-1`` lets the engine pick the longest
        context that fits in GPU memory. The worker's model config is updated,
        and so is the model runner when it is not ``None``.
        """
        self.model_config.max_model_len = max_model_len
        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)
        logger.debug("Updated max_model_len to %d", max_model_len)

445
    @instrument(span_name="Allocate KV cache")
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        In sleep mode the allocation happens inside the CuMem ``"kv_cache"``
        memory pool. Otherwise it happens in a no-op context.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Reuse the shared pool-selection helper instead of duplicating the
        # sleep-mode branch. For the "kv_cache" tag it performs no extra
        # checks: it yields allocator.use_memory_pool(tag="kv_cache") in sleep
        # mode and nullcontext() otherwise.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
464

465
    @instrument(span_name="Warmup (GPU)")
    def compile_or_warm_up_model(self) -> None:
        """Compile and warm up the model, capture CUDA graphs, prime sampling.

        Sequence:
          1. Under ``CompilationMode.VLLM_COMPILE``, run dummy passes for the
             configured compile sizes that are not CUDA graph capture sizes.
             Also run the end of every compile range that none of the known
             sizes falls into. Sizes run largest first, with EPLB skipped.
          2. Run ``kernel_warmup`` and, unless ``enforce_eager``, capture
             CUDA graphs.
          3. If the KV cache size came from profiling, log (at debug level)
             suggested ``--kv-cache-memory`` values.
          4. Warm up V2 kernels, or (V1, last PP rank) the sampler or pooler.
          5. Reset the random seed.
        """
        comp_cfg = self.vllm_config.compilation_config

        sizes_to_warm = []
        if comp_cfg.mode == CompilationMode.VLLM_COMPILE:
            # Extra compile sizes the user wants warmed even though they are
            # not CUDA graph sizes, e.g. the max-num-batched-token size in
            # chunked prefill.
            requested_sizes = comp_cfg.compile_sizes
            sizes_to_warm = [] if requested_sizes is None else requested_sizes.copy()

            captured: list[int] = []
            if comp_cfg.cudagraph_mode != CUDAGraphMode.NONE:
                capture_sizes = comp_cfg.cudagraph_capture_sizes
                captured = [] if capture_sizes is None else capture_sizes
                sizes_to_warm = [s for s in sizes_to_warm if s not in captured]

            # A compile range containing none of the known sizes gets its end
            # added, so every range is compiled or warmed up at least once.
            ranges = comp_cfg.get_compile_ranges()
            covered = set(captured)
            covered.update(s for s in sizes_to_warm if isinstance(s, int))
            for rng in ranges:
                if not any(s in rng for s in covered):
                    sizes_to_warm.append(rng.end)

        # EPLB is skipped so the dummy runs do not record metrics.
        for size in sorted(sizes_to_warm, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warm up and tune execution kernels before CUDA graph capture.
        kernel_warmup(self)

        graph_bytes = (
            0 if self.model_config.enforce_eager else self.model_runner.capture_model()
        )

        profiled_kv = self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        )
        if profiled_kv:
            # memory_profiling does not account for CUDA graph memory and may
            # leave GPU memory unused, so suggest explicit KV cache sizes.
            # Profiling was empirically seen to slightly underestimate usage,
            # hence a 150 MiB safety margin.
            safety_margin = 150 * (1 << 20)
            non_kv_bytes = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + graph_bytes
            )
            to_gpu = self.init_snapshot.free_memory - non_kv_bytes - safety_margin
            to_requested = int(self.requested_memory) - non_kv_bytes - safety_margin

            msg = (
                f"Free memory on device "
                f"({format_gib(self.init_snapshot.free_memory)}/"
                f"{format_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{format_gib(self.requested_memory)} GiB). "
                f"Actual usage is {format_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {format_gib(graph_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory={to_requested}` "
                f"({format_gib(to_requested)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory={to_gpu}` "
                f"({format_gib(to_gpu)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
            )
            logger.debug(msg)

        if self.use_v2_model_runner:
            # V2: a full execute_model + sample_tokens JIT-compiles Triton kernels.
            warmup_kernels(self.model_runner)
        elif get_pp_group().is_last_rank:
            # V1: warm up the sampler/pooler and preallocate max-shape logits
            # and sampling buffers to avoid fragmentation. This intentionally
            # runs after capture_model so torch.cuda.empty_cache does not
            # clear these buffers.
            num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            # EPLB is skipped so the dummy run does not record metrics.
            hidden, last_hidden = self.model_runner._dummy_run(
                num_tokens=num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden)

        # Re-seed so model initialization and profiling do not leave the RNG
        # in a different state.
        set_random_seed(self.model_config.seed)

591
592
593
    def reset_mm_cache(self) -> None:
        """Reset the model runner's multimodal cache."""
        runner = self.model_runner
        runner.reset_mm_cache()

594
595
596
    def reset_encoder_cache(self) -> None:
        """Reset the model runner's encoder cache."""
        runner = self.model_runner
        runner.reset_encoder_cache()

597
598
599
    def get_model(self) -> nn.Module:
        """Return the model module obtained from the model runner."""
        runner = self.model_runner
        return runner.get_model()

600
601
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks as reported by the model runner."""
        supported = self.model_runner.get_supported_tasks()
        return supported
602

603
604
605
606
    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Return the encoder timing statistics collected by the model runner."""
        runner = self.model_runner
        return runner.get_encoder_timing_stats()

607
608
    def annotate_profile(self, scheduler_output):
        """Step the profiler and return an annotation context for this iteration.

        The label encodes the number of context requests/tokens and
        generation requests/tokens, so iterations are easy to tell apart in
        traces. A context request is one that has not generated any tokens
        yet. Returns ``nullcontext()`` when no profiler is active.
        """
        profiler = self.profiler
        if not profiler:
            return nullcontext()

        profiler.step()

        details = compute_iteration_details(scheduler_output)
        # `!s` keeps the str() conversion used for every count.
        label = (
            f"execute_context_{details.num_ctx_requests!s}"
            f"({details.num_ctx_tokens!s})"
            f"_generation_{details.num_generation_requests!s}"
            f"({details.num_generation_tokens!s})"
        )
        return profiler.annotate_context_manager(label)
632

633
634
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Sample tokens via the model runner under ``torch.inference_mode``.

        ``grammar_output`` (possibly ``None``) is forwarded unchanged.
        """
        result = self.model_runner.sample_tokens(grammar_output)
        return result

639
640
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Execute one model step for ``scheduler_output``.

        With pipeline parallelism (PP), a non-first rank first posts an
        asynchronous receive of the previous stage's intermediate tensors.
        A rank whose model runner returns ``IntermediateTensors`` launches a
        non-blocking send of them to the next stage. Pending sends from the
        previous call are waited on before anything else happens.

        Args:
            scheduler_output: The scheduler's decisions for this step.

        Returns:
            The runner's output when it is a ``ModelRunnerOutput``,
            ``AsyncModelRunnerOutput`` or ``None``. Otherwise ``None`` is
            returned after the intermediate tensors have been handed to a
            non-blocking PP send.
        """
        # ensure any previous non-blocking PP sends are complete
        if self._pp_send_work:
            for handle in self._pp_send_work:
                handle.wait()
            self._pp_send_work = []

        intermediate_tensors = None
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0

        # Per-tensor flags forwarded as ``all_gather_tensors`` to the PP
        # send/recv below (only populated for PP + sequence parallelism).
        # NOTE(review): presumably True means "all-gather over the TP group";
        # confirm against irecv_tensor_dict / isend_tensor_dict.
        all_gather_tensors: dict[str, bool] = {}
        compilation_config = self.vllm_config.compilation_config
        parallel_config = self.vllm_config.parallel_config

        if (
            parallel_config.pipeline_parallel_size > 1
            and compilation_config.pass_config.enable_sp
            and forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            num_scheduled_tokens_np = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(num_scheduled_tokens_np),
                    num_scheduled_tokens_np=num_scheduled_tokens_np,
                    max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            all_gather_tensors = {
                "residual": not is_residual_scattered_for_sp(
                    self.vllm_config, batch_desc.num_tokens
                )
            }

        if forward_pass and not get_pp_group().is_first_rank:
            # Post the receive of the previous stage's tensors; completion is
            # tracked by the returned handles/postprocess callbacks.
            tensor_dict, comm_handles, comm_postprocess = (
                get_pp_group().irecv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )
            assert tensor_dict is not None
            intermediate_tensors = AsyncIntermediateTensors(
                tensor_dict,
                comm_handles=comm_handles,
                comm_postprocess=comm_postprocess,
            )

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if (
                self.use_v2_model_runner
                and self.model_runner.is_pooling_model
                and output is None
            ):
                output = self.model_runner.pool()  # type: ignore
            if isinstance(
                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
            ):
                return output

        # Any other output must be intermediate tensors for the next PP stage.
        assert isinstance(output, IntermediateTensors)
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        # launch non-blocking send of intermediate tensors
        self._pp_send_work = get_pp_group().isend_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        return None

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Delegate to the model runner's ``take_draft_token_ids``."""
        runner = self.model_runner
        return runner.take_draft_token_ids()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        # Check if profiling is enabled
        if self.profiler_config is None or self.profiler_config.profiler is None:
736
737
738
739
740
741
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
742

743
        if is_start:
744
745
746
747
748
749
750
751
752
753
754
755
756
            # Generate the trace name by combining prefix with comprehensive rank suffix
            from vllm.distributed.utils import get_worker_rank_suffix

            rank_suffix = get_worker_rank_suffix(global_rank=self.rank)

            # Build the full trace name
            if profile_prefix:
                trace_name = f"{profile_prefix}_{rank_suffix}"
            else:
                trace_name = rank_suffix

            # Create the profiler wrapper only on the first start call
            if self.profiler is None:
757
758
                profiler_type = self.profiler_config.profiler
                if profiler_type == "torch":
759
760
761
762
763
764
765
766
767
                    self.profiler = TorchProfilerWrapper(
                        self.profiler_config,
                        worker_name=trace_name,
                        local_rank=self.local_rank,
                        activities=["CPU", "CUDA"],
                    )
                    logger.debug(
                        "Starting torch profiler with trace name: %s", trace_name
                    )
768
                elif profiler_type == "cuda":
769
770
                    self.profiler = CudaProfilerWrapper(self.profiler_config)
                    logger.debug("Starting CUDA profiler")
771
772
773
                else:
                    logger.warning("Unrecognized profiler: %s", profiler_type)
                    return
774
775
776
777
778
                self.profiler.start()
            else:
                # Profiler already initialized. Restart profiling but keep
                # the original trace name from the first initialization.
                self.profiler.start()
779
        else:
780
781
782
            if self.profiler is None:
                logger.warning("Profiler was not started, nothing to stop.")
                return
783
784
            self.profiler.stop()

785
    def execute_dummy_batch(self) -> None:
        """Run the model runner's dummy pass: ``_dummy_run(1, uniform_decode=True)``."""
        dummy_size = 1
        runner = self.model_runner
        runner._dummy_run(dummy_size, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model runner and return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model runner to drop adapter ``lora_id``; return its result."""
        runner = self.model_runner
        removed = runner.remove_lora(lora_id)
        return removed

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        runner = self.model_runner
        lora_ids = runner.list_loras()
        return lora_ids

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a pin request for ``lora_id`` to the model runner."""
        runner = self.model_runner
        pinned = runner.pin_lora(lora_id)
        return pinned

    def check_health(self) -> None:
        """No-op health check: a running worker is always considered healthy."""
        return None

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts via EPLB before the EP group shrinks.

        Old EP ranks below ``new_ep_size`` keep their index in the rank
        mapping. Ranks at or above it are mapped to ``-1``. After the
        rearrangement the CUDA device is synchronized.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        # Surviving ranks map to themselves; ranks >= new_ep_size map to -1.
        rank_mapping: dict[int, int] = {}
        for rank in range(old_ep_size):
            rank_mapping[rank] = -1 if rank >= new_ep_size else rank

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts via EPLB after the EP group has grown.

        Every pre-existing EP rank keeps its index in the rank mapping, and
        ``global_expert_loads`` is passed through to ``eplb_state.rearrange``.
        ``new_ep_size`` is not used in the body.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        # Identity mapping over the old ranks.
        old_ranks = range(old_ep_size)
        rank_mapping = dict(zip(old_ranks, old_ranks))

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=rank_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Copy the data-parallel settings from ``reconfig_request`` into the config.

        DP size, master IP and master port are always overwritten. DP rank and
        local DP rank are only overwritten when the request does not carry
        ``ReconfigureRankType.KEEP_CURRENT_RANK``.
        """
        cfg = self.vllm_config.parallel_config
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK

        cfg.data_parallel_size = reconfig_request.new_data_parallel_size

        new_rank = reconfig_request.new_data_parallel_rank
        if new_rank != keep_current:
            cfg.data_parallel_rank = new_rank

        new_local_rank = reconfig_request.new_data_parallel_rank_local
        if new_local_rank != keep_current:
            cfg.data_parallel_rank_local = new_local_rank

        cfg.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip
        cfg.data_parallel_master_port = reconfig_request.new_data_parallel_master_port

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """
        Reconfigure MoE modules for a new expert-parallel (EP) world size.

        The following happens in order:

        - Every FusedMoE / SharedFusedMoE module of the model gets its global
          expert count and MoE parallel config rebuilt for ``new_ep_size``.
          The same applies to the drafter model when it is a
          mixture-of-experts model.
        - The EPLB redundant-expert count is recomputed.
        - Communication buffers are re-prepared.
        - The model's physical-expert metadata is updated.

        Args:
            old_ep_size: EP world size before the reconfiguration.
            new_ep_size: EP world size after the reconfiguration.

        Returns:
            The global expert loads gathered by ``eplb_state.rearrange`` when
            ``new_ep_size >= old_ep_size``, otherwise ``None``.

        Raises:
            RuntimeError: If the model, or an MoE drafter model, contains no
                FusedMoE / SharedFusedMoE modules.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Match on the exact class name so that only FusedMoE and
            # SharedFusedMoE modules are collected.
            modules: list[FusedMoE] = [
                module
                for module in model.modules()
                if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
            ]
            if not modules:
                raise RuntimeError(
                    "[Elastic EP] No FusedMoE modules found; cannot reconfigure MoE."
                )
            return modules

        def update_moe_modules(
            moe_modules: list[FusedMoE], num_local_experts: int
        ) -> None:
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of experts"
            # Group sizes do not depend on the module, so query them once.
            tp_size = get_tp_group().world_size
            pcp_size = get_pcp_group().world_size
            dp_size = get_dp_group().world_size
            sp_size = tp_size if parallel_config.use_sequence_parallel_moe else 1
            for module in moe_modules:
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=tp_size,
                    pcp_size_=pcp_size,
                    dp_size_=dp_size,
                    sp_size_=sp_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config

        model_moe_modules = get_moe_modules(self.model_runner.model)
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts
        update_moe_modules(model_moe_modules, num_local_experts)

        drafter_model = None
        if hasattr(self.model_runner, "drafter") and hasattr(
            self.model_runner.drafter, "model"
        ):
            drafter_model = self.model_runner.drafter.model
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            # Scaling down: derive physical / redundant expert counts from
            # the current EPLB state.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            global_expert_loads = None
        else:
            # Scaling up (or same size): take the local physical expert count
            # from EP rank 0, then collect expert loads without shuffling.
            num_local_physical_experts_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts_tensor,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(num_local_physical_experts_tensor.item())
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_loads_any = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )

        prepare_communication_buffer_for_model(self.model_runner.model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed environment for a new DP size.

        The new EP size is ``new_data_parallel_size * TP size * PP size``.

        When scaling down, experts are resharded before the existing
        distributed state is cleaned up. A rank asked to shut down returns
        right after the cleanup.

        Otherwise, in order: the parallel config is updated, the worker
        distributed environment is re-initialized, and MoE modules are
        reconfigured. When scaling up, experts are then resharded onto the
        new ranks.
        """
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )

        if new_ep_size < old_ep_size:
            # Reshard before cleanup_dist_env_and_memory() runs below.
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            # Only ranks outside the new EP size are expected to shut down.
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model's state under ``path`` via ``ShardedStateLoader``.

        ``pattern`` and ``max_size`` are forwarded unchanged to
        ``ShardedStateLoader.save_model``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Delegate saving a tensorized model to the model runner."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)

    def init_weight_transfer_engine(self, init_info: dict) -> None:
        """
        Set up the configured weight transfer engine.

        For the NCCL backend, this creates a process group with the trainer.

        Args:
            init_info: Backend-specific initialization info as a plain dict.
                It is parsed into the backend's typed form before use.

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )
        engine.init_transfer_engine(engine.parse_init_info(init_info))

    def update_weights(self, update_info: dict) -> None:
        """
        Batched weight update from the trainer.

        There are two paths:

        - Checkpoint-format weights are fed to the model's own
          ``load_weights`` between ``initialize_layerwise_reload`` and
          ``finalize_layerwise_reload``.
        - Kernel-format weights are copied in place into the parameters
          named by the incoming ``(name, tensor)`` pairs.

        Args:
            update_info: Dictionary containing backend-specific update info

        Raises:
            RuntimeError: If no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )

        # Parse dict into backend-specific typed dataclass
        typed_update_info = engine.parse_update_info(update_info)

        model = self.model_runner.model

        if typed_update_info.is_checkpoint_format:
            from vllm.model_executor.model_loader.reload import (
                finalize_layerwise_reload,
                initialize_layerwise_reload,
            )

            # Use layerwise reload pattern for checkpoint format weights
            with torch.device(self.device):
                initialize_layerwise_reload(model)
                engine.receive_weights(
                    typed_update_info,
                    load_weights=model.load_weights,
                )
                finalize_layerwise_reload(model, self.model_config)
        else:
            # Weights are already in kernel format, copy directly
            def load_weights_direct(
                weights: list[tuple[str, torch.Tensor]],
            ) -> None:
                # This method is not under inference mode. An in-place copy_
                # into a leaf parameter that requires grad raises unless
                # autograd tracking is disabled.
                with torch.no_grad():
                    for name, weight in weights:
                        model.get_parameter(name).copy_(weight)

            engine.receive_weights(
                typed_update_info,
                load_weights=load_weights_direct,
            )

    def shutdown(self) -> None:
        """Release KV-transfer, profiler and weight-transfer resources.

        Attributes that were never set (e.g. on a partially initialized
        worker) are skipped instead of raising ``AttributeError``.
        """
        # ensure_kv_transfer_shutdown can be None during interpreter shutdown.
        if ensure_kv_transfer_shutdown is not None:
            ensure_kv_transfer_shutdown()

        profiler = getattr(self, "profiler", None)
        if profiler is not None:
            profiler.shutdown()

        weight_transfer_engine = getattr(self, "weight_transfer_engine", None)
        if weight_transfer_engine is not None:
            weight_transfer_engine.shutdown()


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment for a worker.

    Setup runs in this order:

    1. Batch invariance for the configured attention backend.
    2. EPLB environment overrides.
    3. The custom all-reduce toggle.
    4. The global process group.
    5. Model-parallel groups (tensor, pipeline, prefill and decode context
       parallel).
    6. The EC transfer connector.

    Args:
        vllm_config: Full vLLM config; its attention and parallel sections
            drive the setup.
        rank: Rank of this worker passed to ``init_distributed_environment``.
        distributed_init_method: Process-group init method; ``"env://"`` is
            used when it is ``None`` or empty.
        local_rank: Local rank of this worker.
        backend: torch.distributed backend name.
    """
    attention_config = vllm_config.attention_config
    parallel_config = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance(attention_config.backend)
    override_envs_for_eplb(parallel_config)
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_method = distributed_init_method or "env://"
    init_distributed_environment(
        parallel_config.world_size, rank, init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.prefill_context_parallel_size,
        parallel_config.decode_context_parallel_size,
    )

    # Init the EC connector here, before the KV caches are initialized.
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)