# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""

import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any, cast
10

11
import numpy as np
12
13
import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
18
from vllm.config.compilation import CompilationMode
19
20
21
22
23
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
24
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
25
26
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    ensure_kv_transfer_shutdown,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
31
from vllm.distributed.parallel_state import (
    get_pcp_group,
    get_pp_group,
    get_tp_group,
)
36
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
37
from vllm.logger import init_logger
38
from vllm.lora.request import LoRARequest
39
from vllm.model_executor.models.interfaces import is_mixture_of_experts
40
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
41
from vllm.platforms import current_platform
42
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
43
from vllm.sequence import IntermediateTensors
44
from vllm.tasks import SupportedTask
45
from vllm.tracing import instrument
46
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
47
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
49
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
50
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
51
52
53
54
55
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
56
from vllm.v1.utils import compute_iteration_details, report_usage_stats
57
from vllm.v1.worker.utils import is_residual_scattered_for_sp
58
from vllm.v1.worker.worker_base import WorkerBase
59
from vllm.v1.worker.workspace import init_workspace_manager
60

61
62
from .utils import request_memory

63
64
65
if TYPE_CHECKING:
    # Needed only by static type checkers; never imported at runtime.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class Worker(WorkerBase):
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Set up per-worker state that does not need the GPU yet.

        Device selection, distributed initialization and model-runner
        construction happen later, in `init_device`.

        Args:
            vllm_config: Global vLLM configuration.
            local_rank: Node-local rank (shifted by the data-parallel offset
                in `init_device` when that adjustment applies).
            rank: Global rank of this worker.
            distributed_init_method: Init method later passed to the
                distributed-environment setup.
            is_driver_worker: Forwarded to `WorkerBase`.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Float32 matmul precision is taken from VLLM_FLOAT32_MATMUL_PRECISION.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        # CPU copies of model buffers taken by a level-2 `sleep` and copied
        # back (then cleared) by `wake_up`.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Weight transfer engine: built eagerly here, but only when a weight
        # transfer config is present; otherwise None.
        transfer_config = self.vllm_config.weight_transfer_config
        self.weight_transfer_engine = (
            None
            if transfer_config is None
            else WeightTransferEngineFactory.create_engine(
                transfer_config,
                self.vllm_config.parallel_config,
            )
        )

        # Optional Torch/CUDA profiler, selected by `profiler_config.profiler`;
        # any other value leaves profiling disabled (None).
        self.profiler: Any | None = None
        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler == "torch":
            self.profiler = TorchProfilerWrapper(
                profiler_config,
                worker_name=f"{vllm_config.instance_id}-rank-{self.rank}",
                local_rank=self.local_rank,
                activities=["CPU", "CUDA"],
            )
        elif profiler_config.profiler == "cuda":
            self.profiler = CudaProfilerWrapper(profiler_config)

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

    def sleep(self, level: int = 1) -> None:
        """Release GPU memory through the CuMem allocator.

        Args:
            level: ``1`` passes ``("weights",)`` as the allocator's offload
                tags; any other level passes no offload tags. Level ``2``
                additionally copies every model buffer to CPU beforehand so
                that `wake_up` can restore it.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before_sleep = torch.cuda.mem_get_info()[0]

        if level == 2:
            # Buffers are saved explicitly; `wake_up` copies them back.
            named_buffers = self.model_runner.model.named_buffers()
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone() for buf_name, buf in named_buffers
            }

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after_sleep, total_memory = torch.cuda.mem_get_info()
        freed = free_after_sleep - free_before_sleep
        still_used = total_memory - free_after_sleep
        assert freed >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(freed),
            format_gib(still_used),
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Re-map memory released by `sleep` and restore dependent state.

        Args:
            tags: Allocator tags to wake up, forwarded unchanged to
                ``CuMemAllocator.wake_up``. ``None`` is treated here as
                covering both the "weights" and the "kv_cache" tags.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers saved by a level-2 sleep. The model is loaded
        # under the "weights" memory pool (see `load_model`), so its buffers
        # are only safe to write once that pool has been woken up. If only
        # other tags (e.g. "kv_cache") are woken, keep the saved copies for a
        # later call instead of copying into memory that may still be
        # unmapped.
        weights_woken = tags is None or "weights" in tags
        if self._sleep_saved_buffers and weights_woken:
            saved = self._sleep_saved_buffers
            for name, buffer in self.model_runner.model.named_buffers():
                if name in saved:
                    buffer.data.copy_(saved[name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up,
        # the internal state of cache_engine must be reset,
        # especially the FP8 scaling factor.
        if (
            (tags is None or "kv_cache" in tags)
            and self.cache_config.cache_dtype.startswith("fp8")
            and hasattr(self.model_runner, "init_fp8_kv_scales")
        ):
            self.model_runner.init_fp8_kv_scales()

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return the CuMem memory-pool context for ``tag`` in sleep mode.

        When sleep mode is disabled a ``nullcontext`` is returned, leaving
        allocations untouched.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        # `get_instance` appears to return a process-wide allocator; any usage
        # already recorded when the weights pool is requested means another
        # instance in this process is using it.
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU and CPU KV-cache block counts on the cache config."""
        cache_config = self.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

    @instrument(span_name="Init device")
    def init_device(self):
        """Bind this worker to its CUDA device and build the model runner.

        In order:
        1. Pick the device. ``local_rank`` is shifted by the local
           data-parallel offset when this process selects devices itself.
        2. Initialize the distributed environment.
        3. Seed the RNGs.
        4. Take the initial memory snapshot.
        5. Set up the workspace manager.
        6. Construct the V1 or V2 GPU model runner.

        Raises:
            RuntimeError: If the configured device type is not ``"cuda"``.
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        parallel_config = self.parallel_config
        selects_own_device = (
            parallel_config.distributed_executor_backend
            not in ("ray", "external_launcher")
            and parallel_config.data_parallel_backend != "ray"
            and parallel_config.nnodes_within_dp == 1
        )
        if selects_own_device:
            # Prefer the node-local DP rank; fall back to the DP index.
            dp_rank = parallel_config.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = parallel_config.data_parallel_index

            # local_rank = DP_LOCAL_RANK * (PP * TP) + TP_LOCAL_RANK
            ranks_per_dp = (
                parallel_config.pipeline_parallel_size
                * parallel_config.tensor_parallel_size
            )
            self.local_rank += dp_rank * ranks_per_dp
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            num_visible_devices = (
                torch.cuda.device_count() if torch.cuda.is_available() else 0
            )
            assert parallel_config.local_world_size <= num_visible_devices, (
                f"local_world_size ({parallel_config.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({num_visible_devices})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # The distributed environment is initialized BEFORE the memory
        # snapshot so that NCCL buffers are already allocated when the
        # available memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        if self.use_v2_model_runner:
            logger.info_once("Using V2 Model Runner", scope="local")

        set_random_seed(self.model_config.seed)

        # Snapshot memory now that NCCL is up, after dropping cached blocks.
        gc.collect()
        torch.cuda.empty_cache()
        snapshot = MemorySnapshot(device=self.device)
        self.init_snapshot = snapshot
        self.requested_memory = request_memory(snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Workspace manager: `num_ubatches` is 2 with `enable_dbo`, else 1.
        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, num_ubatches)

        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        # Only rank 0 reports usage stats (if usage stats are enabled).
        if self.rank == 0:
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model, inside the "weights" memory pool in sleep mode.

        The ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` environment variable set to
        ``"1"`` is forwarded to the runner as ``eep_scale_up=True``.
        """
        is_scale_up_launch = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        weights_pool = self._maybe_get_memory_pool_context(tag="weights")
        with weights_pool, set_current_vllm_config(self.vllm_config):
            self.model_runner.load_model(eep_scale_up=is_scale_up_launch)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

    def reload_weights(self, *args, **kwargs) -> None:
        """Forward all arguments to the model runner's ``reload_weights``."""
        reload = self.model_runner.reload_weights
        reload(*args, **kwargs)

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        If ``cache_config.kv_cache_memory_bytes`` is set (non-zero), a profile
        run is still executed because it compiles the model for
        ``max_num_batched_tokens``. Memory profiling is then skipped and that
        value is returned unchanged, ignoring ``gpu_memory_utilization``.

        Otherwise a dummy forward pass is profiled against the snapshot taken
        in `init_device`. The result is the requested memory minus all the
        non-KV-cache memory observed while profiling.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            Bytes available for the KV cache.
        """
        manual_kv_cache_bytes = self.cache_config.kv_cache_memory_bytes
        if manual_kv_cache_bytes:
            # Still need a profile run, which compiles the model for
            # max_num_batched_tokens.
            self.model_runner.profile_run()
            manual_notice = (
                f"Initial free memory {format_gib(self.init_snapshot.free_memory)} "
                f"GiB, reserved {format_gib(manual_kv_cache_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(manual_notice)
            return manual_kv_cache_bytes

        # Profile a forward pass with dummy inputs to measure peak usage.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        initial_free = self.init_snapshot.free_memory
        free_after_profile = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert initial_free >= free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(initial_free)} GiB, "
            f"current free memory {format_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Part of the initially free memory that lies outside the request.
        unrequested_memory = initial_free - self.requested_memory
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(initial_free),
            self.cache_config.gpu_memory_utilization,
            format_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_after_profile),
            format_gib(free_after_profile - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Get KV connector metadata from this worker if available.

        Returns:
            ``{tp_rank: metadata}`` keyed by this worker's rank in the TP
            group. Returns ``None`` when there is no KV transfer group, or
            when its connector has no handshake metadata to exchange across
            workers.
        """
        if has_kv_transfer_group():
            handshake = get_kv_transfer_group().get_handshake_metadata()
            if handshake is not None:
                return {get_tp_group().rank_in_group: handshake}
        return None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache specs reported by the model runner."""
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        return kv_cache_spec

    def update_max_model_len(self, max_model_len: int) -> None:
        """Adopt the engine's auto-fitted ``max_model_len``.

        Used when ``max_model_len=-1`` lets the engine choose the longest
        context that fits in GPU memory. The worker's model config and its
        model runner (when one exists) are updated to match that decision.
        """
        self.model_config.max_model_len = max_model_len
        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)
        logger.debug("Updated max_model_len to %d", max_model_len)

    @instrument(span_name="Allocate KV cache")
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        With sleep mode enabled, the allocation happens inside the "kv_cache"
        memory pool, via `_maybe_get_memory_pool_context`.
        """
        # NOTE(Kuntai): the KV cache connector needs `kv_cache_config` and must
        # be initialized BEFORE `initialize_kv_cache`, which injects KV cache
        # groups unrelated to the connector (e.g. KV cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    @instrument(span_name="Warmup (GPU)")
    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up the model, capture CUDA graphs, warm the sampler.

        The order is deliberate:
        - Compile-size warmups and kernel warmup run before CUDA graph capture.
        - The sampler/pooler warmup runs after capture.
        - The random seed is reset last.
        """
        compilation_config = self.vllm_config.compilation_config
        warmup_sizes = []

        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Sizes users want compiled even though they are not cudagraph
            # capture sizes, e.g. the max-num-batched token size in chunked
            # prefill.
            compile_sizes = compilation_config.compile_sizes
            warmup_sizes = [] if compile_sizes is None else compile_sizes.copy()

            cg_capture_sizes: list[int] = []
            if compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                cg_sizes = compilation_config.cudagraph_capture_sizes
                if cg_sizes is not None:
                    cg_capture_sizes = cg_sizes
                warmup_sizes = [s for s in warmup_sizes if s not in cg_capture_sizes]

            # Every compile range must be compiled: when no warmup or capture
            # size falls inside a range, warm up the range's end instead.
            covered_sizes = set(cg_capture_sizes)
            covered_sizes.update(s for s in warmup_sizes if isinstance(s, int))
            for compile_range in compilation_config.get_compile_ranges():
                if not any(s in compile_range for s in covered_sizes):
                    warmup_sizes.append(compile_range.end)

        # EPLB is skipped so that dummy runs do not record metrics.
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warm up and tune the kernels used during execution before capture.
        kernel_warmup(self)

        if self.model_config.enforce_eager:
            cuda_graph_memory_bytes = 0
        else:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # `memory_profiling` does not account for CUDA graph memory and
            # may not use all GPU memory. Suggest explicit kv cache sizes for
            # users who want fine-grained control. Profiling was observed to
            # slightly underestimate usage, hence a 150 MiB safety margin.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_bytes_for_full_gpu = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_bytes_for_requested = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            suggestion = (
                f"Free memory on device "
                f"({format_gib(self.init_snapshot.free_memory)}/"
                f"{format_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{format_gib(self.requested_memory)} GiB). "
                f"Actual usage is {format_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {format_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_bytes_for_requested}` "
                f"({format_gib(kv_bytes_for_requested)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_bytes_for_full_gpu}` "
                f"({format_gib(kv_bytes_for_full_gpu)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
            )
            logger.debug(suggestion)

        # Warm up the sampler (or pooler) at the largest possible shape. This
        # preallocates logits and other sampling buffers and avoids memory
        # fragmentation. It intentionally runs after `capture_model`, so these
        # buffers are not released by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            # EPLB is skipped so that dummy runs do not record metrics.
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed so the random state is not affected by model
        # initialization and profiling.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        """Delegate to the model runner's ``reset_mm_cache``."""
        runner = self.model_runner
        runner.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        """Delegate to the model runner's ``reset_encoder_cache``."""
        runner = self.model_runner
        runner.reset_encoder_cache()

    def get_model(self) -> nn.Module:
        """Return the model module held by the model runner."""
        model = self.model_runner.get_model()
        return model

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        supported = self.model_runner.get_supported_tasks()
        return supported

    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Get encoder timing stats from model runner."""
        timing_stats = self.model_runner.get_encoder_timing_stats()
        return timing_stats

    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels this iteration in the trace.

        The label lists the number of context requests (requests that have
        not yet generated any tokens) and generation requests, each with its
        token count, so iterations are easy to tell apart. The profiler is
        also stepped once. Without a profiler this returns ``nullcontext()``.
        """
        profiler = self.profiler
        if not profiler:
            return nullcontext()

        profiler.step()

        details = compute_iteration_details(scheduler_output)
        annotation = (
            f"execute_context_{details.num_ctx_requests!s}"
            f"({details.num_ctx_tokens!s})"
            f"_generation_{details.num_generation_requests!s}"
            f"({details.num_generation_tokens!s})"
        )
        return profiler.annotate_context_manager(annotation)

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Sample tokens through the model runner under inference mode.

        Args:
            grammar_output: Passed unchanged to the model runner; may be
                ``None``.
        """
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)

    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one scheduled step on this pipeline-parallel stage.

        Steps:
        1. Non-first PP ranks receive intermediate tensors from the previous
           stage (only when tokens are scheduled).
        2. A ``ModelRunnerOutput``, ``AsyncModelRunnerOutput`` or ``None``
           from the runner is returned directly.
        3. An ``IntermediateTensors`` result is sent to the next PP stage,
           and ``None`` is returned.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0

        # Tensor name -> flag passed as `all_gather_tensors` to the PP
        # send/recv. It is only filled when PP is combined with sequence
        # parallelism.
        all_gather_tensors = {}
        if (
            self.vllm_config.parallel_config.pipeline_parallel_size > 1
            and self.vllm_config.compilation_config.pass_config.enable_sp
            and forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            per_request_tokens = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(per_request_tokens),
                    num_scheduled_tokens_np=per_request_tokens,
                    max_num_scheduled_tokens=per_request_tokens.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            residual_scattered = is_residual_scattered_for_sp(
                self.vllm_config, batch_desc.num_tokens
            )
            all_gather_tensors = {"residual": not residual_scattered}

        intermediate_tensors = None
        if forward_pass and not get_pp_group().is_first_rank:
            tensor_dict = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            assert tensor_dict is not None
            intermediate_tensors = IntermediateTensors(tensor_dict)

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if isinstance(
                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
            ):
                return output

        # Only a non-last PP stage gets here: forward its intermediate
        # tensors to the next stage.
        assert isinstance(output, IntermediateTensors)
        assert (
            self.vllm_config.parallel_config.distributed_executor_backend
            != "external_launcher"
            and not get_pp_group().is_last_rank
        )
        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )
        return None

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the model runner's draft token ids, or ``None``."""
        draft_token_ids = self.model_runner.take_draft_token_ids()
        return draft_token_ids

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the configured profiler.

        Raises:
            RuntimeError: If no profiler was configured.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
        toggle = profiler.start if is_start else profiler.stop
        toggle()

    def execute_dummy_batch(self) -> None:
        """Trigger the model runner's dummy run (argument ``1``, uniform decode)."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA registration to the model runner and return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal to the model runner and return its result."""
        removed = self.model_runner.remove_lora(lora_id)
        return removed

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        lora_ids = self.model_runner.list_loras()
        return lora_ids

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a LoRA pin request to the model runner and return its result."""
        pinned = self.model_runner.pin_lora(lora_id)
        return pinned

    def check_health(self) -> None:
        """Report worker health.

        A worker that is running is always considered healthy, so this is a
        no-op.
        """
        return None

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts off the EP ranks that disappear when scaling down.

        Old EP ranks below ``new_ep_size`` keep their index; ranks at or above
        it are mapped to ``-1``. The EPLB state is then rearranged with a real
        shuffle, and the call waits for queued CUDA work before returning.
        """
        from vllm.distributed.parallel_state import get_ep_group

        ep_group = get_ep_group()
        if ep_group.rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        rank_mapping = {}
        for rank in range(old_ep_size):
            rank_mapping[rank] = rank if rank < new_ep_size else -1

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        # Wait for all queued CUDA work (including the shuffle) to finish.
        torch.cuda.synchronize()
        if ep_group.rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Rearrange experts across the enlarged EP group after scaling up.

        Every pre-existing EP rank keeps its index in the rank mapping, and
        ``global_expert_loads`` is forwarded to the EPLB rearrangement.
        ``new_ep_size`` is currently not read.
        """
        from vllm.distributed.parallel_state import get_ep_group

        ep_group = get_ep_group()
        if ep_group.rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        rank_mapping = dict(zip(range(old_ep_size), range(old_ep_size)))

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=rank_mapping,
        )
        if ep_group.rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Copy the data-parallel settings from ``reconfig_request`` into this
        worker's parallel config.

        The DP size, master IP and master port are always overwritten. The
        global and local DP ranks are only overwritten when the request does
        not carry ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.
        """
        cfg = self.vllm_config.parallel_config
        req = reconfig_request
        keep = ReconfigureRankType.KEEP_CURRENT_RANK

        cfg.data_parallel_size = req.new_data_parallel_size

        new_rank = req.new_data_parallel_rank
        if new_rank != keep:
            cfg.data_parallel_rank = new_rank

        new_local_rank = req.new_data_parallel_rank_local
        if new_local_rank != keep:
            cfg.data_parallel_rank_local = new_local_rank

        cfg.data_parallel_master_ip = req.new_data_parallel_master_ip
        cfg.data_parallel_master_port = req.new_data_parallel_master_port

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """
        Reconfigure the model's MoE modules (and those of an MoE drafter, if
        present) for an expert-parallel size change from ``old_ep_size`` to
        ``new_ep_size``.

        Each module's ``FusedMoEParallelConfig`` is rebuilt from the sizes of
        the currently initialized TP / PCP / DP groups. Call this only after
        the new distributed environment has been set up.

        Returns:
            The global expert loads obtained from the EPLB state when
            ``new_ep_size >= old_ep_size``; ``None`` when scaling down.

        Raises:
            RuntimeError: if the model contains no FusedMoE / SharedFusedMoE
                modules.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Matched by exact class name: only FusedMoE and SharedFusedMoE.
            return [
                module
                for module in model.modules()
                if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
            ]

        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
            # The new global expert count is derived from the local count, so
            # every module must agree on it.
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of experts"
            for module in moe_modules:
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config
            return moe_modules

        model_moe_modules = get_moe_modules(self.model_runner.model)
        if not model_moe_modules:
            # Previously surfaced as an opaque IndexError on [0] below.
            raise RuntimeError(
                "Cannot reconfigure MoE: the model has no FusedMoE modules."
            )
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts

        update_moe_modules(model_moe_modules, num_local_experts)
        drafter_model = None
        if hasattr(self.model_runner, "drafter") and hasattr(
            self.model_runner.drafter, "model"
        ):
            drafter_model = self.model_runner.drafter.model
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            # Scaling down: each rank keeps its local expert count, and the
            # new physical/logical totals are read back from the EPLB state
            # (reinitialize_distributed reshards it before calling this).
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            global_expert_loads = None
        else:
            # Scaling up (or unchanged size): take the local physical expert
            # count from EP-group rank 0 so every rank uses the same value.
            num_local_physical_experts_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts_tensor,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(num_local_physical_experts_tensor.item())
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            # With execute_shuffle=False this call is used only for the expert
            # loads it returns; no weights are moved here.
            global_expert_loads_any = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )
        prepare_communication_buffer_for_model(self.model_runner.model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild this worker's distributed state for elastic EP.

        Sequence:
          1. If the EP group shrinks, reshard experts off the departing ranks.
          2. Clean up the current distributed environment.
          3. Return early if this rank is being shut down.
          4. Apply the new parallel config and re-initialize distributed state.
          5. Reconfigure MoE modules; if the EP group grew, reshard experts
             onto the enlarged group.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        ep_group = get_ep_group()
        old_ep_size = ep_group.world_size
        old_ep_rank = ep_group.rank

        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        shrinking = new_ep_size < old_ep_size
        growing = new_ep_size > old_ep_size

        if shrinking:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        shutting_down = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )
        if shutting_down:
            # Only ranks outside the new EP size are expected to be retired.
            assert old_ep_rank >= new_ep_size
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if growing:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model's sharded state to ``path`` via ShardedStateLoader.

        ``pattern`` and ``max_size`` are passed through unchanged to
        ``ShardedStateLoader.save_model``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the model with tensorizer by delegating to the model runner."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)

    def init_weight_transfer_engine(self, init_info: dict) -> None:
        """
        Set up the weight transfer engine from backend-specific init info.

        For the NCCL backend, this creates a process group with the trainer.

        Args:
            init_info: Backend-specific initialization info; it is parsed into
                the backend's typed dataclass before being applied.

        Raises:
            RuntimeError: if no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )
        engine.init_transfer_engine(engine.parse_init_info(init_info))

    def update_weights(self, update_info: dict) -> None:
        """
        Apply a batched weight update received from the trainer.

        The parsed update info's ``is_checkpoint_format`` selects the path:

        * checkpoint format: weights are fed to ``model.load_weights`` inside
          a layerwise-reload session, with ``self.device`` as default device;
        * kernel format: each received ``(name, tensor)`` pair is copied in
          place into the model parameter of the same name.

        Args:
            update_info: Backend-specific update info; it is parsed into the
                backend's typed dataclass before use.

        Raises:
            RuntimeError: if no weight transfer engine is configured.
        """
        if self.weight_transfer_engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )

        # Parse dict into backend-specific typed dataclass
        typed_update_info = self.weight_transfer_engine.parse_update_info(update_info)

        model = self.model_runner.model

        if typed_update_info.is_checkpoint_format:
            from vllm.model_executor.model_loader.reload import (
                finalize_layerwise_reload,
                initialize_layerwise_reload,
            )

            # Use layerwise reload pattern for checkpoint format weights
            with torch.device(self.device):
                initialize_layerwise_reload(model)
                self.weight_transfer_engine.receive_weights(
                    typed_update_info,
                    load_weights=model.load_weights,
                )
                finalize_layerwise_reload(model, self.model_config)
        else:
            # Weights are already in kernel format, copy directly
            def load_weights_direct(
                weights: list[tuple[str, torch.Tensor]],
            ) -> None:
                # In-place copy_ into a leaf parameter with requires_grad=True
                # raises under autograd, so disable gradient tracking here.
                with torch.no_grad():
                    for name, weight in weights:
                        model.get_parameter(name).copy_(weight)

            self.weight_transfer_engine.receive_weights(
                typed_update_info,
                load_weights=load_weights_direct,
            )

    def shutdown(self) -> None:
        """Release worker resources: KV transfer, profiler, weight transfer.

        Instance attributes are read with ``getattr`` so that shutdown does
        not raise ``AttributeError`` on a worker whose attributes were never
        set (e.g. if construction did not complete).
        """
        # Module-level names such as ensure_kv_transfer_shutdown can already
        # be None during interpreter shutdown.
        if ensure_kv_transfer_shutdown is not None:
            ensure_kv_transfer_shutdown()

        if (profiler := getattr(self, "profiler", None)) is not None:
            profiler.shutdown()

        if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
            weight_transfer_engine.shutdown()


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Configures batch invariance for the attention backend and the custom
    all-reduce toggle, initializes the global process group (using
    ``env://`` when no init method is given), creates the tensor / pipeline /
    context parallel groups, and initializes the EC transfer connector.
    """
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    parallel_config = vllm_config.parallel_config

    init_batch_invariance(vllm_config.attention_config.backend)
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(
        parallel_config.world_size,
        rank,
        distributed_init_method or "env://",
        local_rank,
        backend,
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.prefill_context_parallel_size,
        parallel_config.decode_context_parallel_size,
    )

    # Init ec connector here before KV caches init
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)