"csrc/attention/attention_kernels_opt.cu" did not exist on "dacaf5a40056c40be4a84f0aec42278335f76dae"
gpu_worker.py 42.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any, cast
10

11
import numpy as np
12
13
import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
18
from vllm.config.compilation import CompilationMode
19
20
21
22
23
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
24
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
25
26
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
27
    ensure_kv_transfer_shutdown,
28
29
30
    get_kv_transfer_group,
    has_kv_transfer_group,
)
31
from vllm.distributed.parallel_state import (
32
    get_pcp_group,
33
34
35
    get_pp_group,
    get_tp_group,
)
36
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
37
from vllm.logger import init_logger
38
from vllm.lora.request import LoRARequest
39
from vllm.model_executor.models.interfaces import is_mixture_of_experts
40
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
41
from vllm.platforms import current_platform
42
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
43
from vllm.sequence import IntermediateTensors
44
from vllm.tasks import SupportedTask
45
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
46
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
47
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
48
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
49
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
50
51
52
53
54
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
55
from vllm.v1.utils import compute_iteration_details, report_usage_stats
56
from vllm.v1.worker.utils import is_residual_scattered_for_sp
57
from vllm.v1.worker.worker_base import WorkerBase
58
from vllm.v1.worker.workspace import init_workspace_manager
59

60
61
from .utils import request_memory

62
63
64
# Module-level logger for the GPU worker.
logger = init_logger(__name__)

if TYPE_CHECKING:
    # TYPE_CHECKING is False at runtime, so these imports are only seen by
    # static type checkers; the names can be used in annotations without
    # being imported when the module executes.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


class Worker(WorkerBase):
70
71
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Construct a GPU worker.

        Args:
            vllm_config: Full engine configuration.
            local_rank: Node-local rank; `init_device` may offset it by the
                data-parallel rank before choosing the CUDA device.
            rank: Global rank of this worker.
            distributed_init_method: Init method forwarded to the
                distributed environment setup.
            is_driver_worker: Whether this worker acts as the driver.

        The CUDA device and the model runner are created later, in
        `init_device`.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Apply the float32 matmul precision requested through the vLLM env.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        # CPU copies of model buffers taken by a level-2 `sleep`; consumed and
        # cleared by `wake_up`.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Weight transfer engine: only built when a transfer config exists.
        if self.vllm_config.weight_transfer_config is None:
            self.weight_transfer_engine = None
        else:
            self.weight_transfer_engine = WeightTransferEngineFactory.create_engine(
                self.vllm_config.weight_transfer_config,
                self.vllm_config.parallel_config,
            )

        # Torch/CUDA profiler, chosen by `profiler_config.profiler`; stays
        # None for any other value.
        self.profiler: Any | None = None
        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler == "torch":
            self.profiler = TorchProfilerWrapper(
                profiler_config,
                worker_name=f"{vllm_config.instance_id}-rank-{self.rank}",
                local_rank=self.local_rank,
                activities=["CPU", "CUDA"],
            )
        elif profiler_config.profiler == "cuda":
            self.profiler = CudaProfilerWrapper(profiler_config)

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
        if self.use_v2_model_runner:
            logger.info_once("Using V2 Model Runner", scope="global")

    def sleep(self, level: int = 1) -> None:
        """Release GPU memory through the CuMem allocator.

        Args:
            level: Level 1 asks the allocator to offload the ``"weights"``
                tag; any other level passes no offload tags. Level 2 also
                copies every model buffer to CPU first so that `wake_up`
                can restore it.

        Raises:
            AssertionError: if free device memory is lower after sleeping
                than before.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        if level == 2:
            # No tags are offloaded at level 2, so keep CPU copies of the
            # buffers for `wake_up` to copy back.
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone()
                for buf_name, buf in self.model_runner.model.named_buffers()
            }

        cumem = CuMemAllocator.get_instance()
        offload_tags = ("weights",) if level == 1 else tuple()
        cumem.sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        used_bytes = total - free_after
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(freed_bytes),
            format_gib(used_bytes),
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Re-materialize memory released by `sleep`.

        Wakes the CuMem allocator for `tags` (``None`` is passed through
        unchanged) and copies back any buffers saved by a level-2 sleep.
        When the KV cache was among the woken tags and the cache dtype is
        FP8, it also re-initializes the FP8 KV scales, provided the model
        runner exposes `init_fp8_kv_scales`.

        Args:
            tags: Allocator tags to wake up, or ``None``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        # Restore buffers snapshotted by a level-2 sleep, then drop the copies.
        saved = self._sleep_saved_buffers
        if saved:
            for buf_name, buf in self.model_runner.model.named_buffers():
                if buf_name in saved:
                    buf.data.copy_(saved[buf_name].data)
            self._sleep_saved_buffers = {}

        # A freshly woken KV cache needs the cache-engine state reset,
        # in particular the FP8 scaling factors.
        kv_cache_woken = tags is None or "kv_cache" in tags
        if (
            kv_cache_woken
            and self.cache_config.cache_dtype.startswith("fp8")
            and hasattr(self.model_runner, "init_fp8_kv_scales")
        ):
            self.model_runner.init_fp8_kv_scales()

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context manager that places allocations in the sleep pool.

        Without sleep mode this is a no-op ``nullcontext()``. With sleep mode
        it is the CuMem allocator's memory-pool context for `tag`.

        Raises:
            AssertionError: for the ``"weights"`` tag, if the allocator
                already reports non-zero usage (sleep mode supports only one
                instance per process).
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        cumem = CuMemAllocator.get_instance()
        if tag == "weights":
            assert cumem.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return cumem.use_memory_pool(tag=tag)

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU and CPU KV-cache block counts on the cache config."""
        cache_cfg = self.cache_config
        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind the worker to its CUDA device and construct the model runner.

        The steps run in this order:

        1. Optionally offset `local_rank` by ``dp_local_rank * (PP * TP)``.
        2. Select the device, check dtype support, and initialize the
           distributed environment.
        3. Take the initial memory snapshot and compute requested memory.
        4. Initialize the workspace manager and build the V1 or V2 runner.

        Step 1 is skipped for Ray / external-launcher executors, the Ray DP
        backend, or multi-node DP. Step 2 happens before step 3 so that NCCL
        buffers are already allocated when memory is measured.

        Raises:
            RuntimeError: if the configured device type is not ``"cuda"``.
            AssertionError: if the DP-adjusted local rank or the local world
                size exceeds the number of visible CUDA devices.
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # Ray sets this env var, and it makes graph building fail.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        parallel_config = self.parallel_config
        offset_local_rank_for_dp = (
            parallel_config.distributed_executor_backend
            not in ("ray", "external_launcher")
            and parallel_config.data_parallel_backend != "ray"
            and parallel_config.nnodes_within_dp == 1
        )
        if offset_local_rank_for_dp:
            # Prefer the local DP rank; fall back to the global DP index.
            dp_local_rank = parallel_config.data_parallel_rank_local
            if dp_local_rank is None:
                dp_local_rank = parallel_config.data_parallel_index

            tp_pp_world_size = (
                parallel_config.pipeline_parallel_size
                * parallel_config.tensor_parallel_size
            )
            # local_rank = DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_local_rank * tp_pp_world_size
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            visible_device_count = (
                torch.cuda.device_count() if torch.cuda.is_available() else 0
            )
            assert parallel_config.local_world_size <= visible_device_count, (
                f"local_world_size ({parallel_config.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({visible_device_count})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Distributed init comes BEFORE the memory snapshot so that NCCL
        # buffers are already counted in it.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        set_random_seed(self.model_config.seed)

        # Free cached/garbage memory, then snapshot what is available now.
        gc.collect()
        torch.cuda.empty_cache()
        snapshot = MemorySnapshot(device=self.device)
        self.init_snapshot = snapshot
        self.requested_memory = request_memory(snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # The workspace manager gets 2 micro-batches when `enable_dbo` is set,
        # otherwise 1.
        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, num_ubatches)

        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        # Only rank 0 reports usage stats (collected if usage stats are enabled).
        if self.rank == 0:
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model weights through the model runner.

        Loading runs inside two contexts:

        * the sleep-mode ``"weights"`` memory pool (a no-op ``nullcontext``
          when sleep mode is disabled);
        * this worker's config, set as the current vLLM config.
        """
        # True when env var VLLM_ELASTIC_EP_SCALE_UP_LAUNCH is exactly "1".
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        # Both managers must be entered. `with a and b:` would enter only `b`,
        # because a context-manager object is always truthy, so the weights
        # memory pool would silently never be activated.
        with (
            self._maybe_get_memory_pool_context(tag="weights"),
            set_current_vllm_config(self.vllm_config),
        ):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward the config `overrides` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

    def reload_weights(self, *args, **kwargs) -> None:
        """Forward a weight-reload request, with all its arguments, to the runner."""
        runner = self.model_runner
        runner.reload_weights(*args, **kwargs)

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return how many bytes of GPU memory the KV cache may use.

        There are two modes.

        * ``cache_config.kv_cache_memory_bytes`` is set (truthy): a profile
          run still executes, because it compiles the model for
          ``max_num_batched_tokens``. Memory profiling is skipped, and the
          configured value is returned unchanged. ``gpu_memory_utilization``
          is not respected in this mode.
        * Otherwise: a dummy forward pass runs under ``memory_profiling``.
          The result is ``self.requested_memory`` (computed in
          ``init_device``) minus the non-KV-cache memory measured during
          profiling. The measured non-torch increase and torch peak increase
          are stored on ``self.non_torch_memory`` and
          ``self.peak_activation_memory``.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Raises:
            AssertionError: if free GPU memory grew during profiling, which
                means another process released memory meanwhile.
        """
        configured_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if configured_kv_bytes:
            # Still needed: the profile run compiles the model for
            # max_num_batched_tokens.
            self.model_runner.profile_run()
            skip_msg = (
                f"Initial free memory {format_gib(self.init_snapshot.free_memory)} "
                f"GiB, reserved {format_gib(configured_kv_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(skip_msg)
            return configured_kv_bytes

        # Profile a dummy forward pass to measure the model's memory footprint.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory >= free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {format_gib(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )

        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Memory that was free at startup but lies outside the requested budget.
        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            format_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_gpu_memory),
            format_gib(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return this worker's KV-connector handshake metadata keyed by TP rank.

        Returns ``None`` in two cases:

        * there is no KV transfer group;
        * the connector has no handshake metadata to exchange across workers.
        """
        if has_kv_transfer_group():
            metadata = get_kv_transfer_group().get_handshake_metadata()
            if metadata is not None:
                return {get_tp_group().rank_in_group: metadata}
        return None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache specs, keyed by name, reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

    def update_max_model_len(self, max_model_len: int) -> None:
        """Propagate an auto-fitted `max_model_len` to this worker.

        This is used with ``max_model_len=-1``. In that mode the engine picks
        the longest context that fits in GPU memory, and workers must update
        their cached value to match. The model config is updated first. The
        model runner is then updated when it is not ``None``.
        """
        self.model_config.max_model_len = max_model_len
        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)
        logger.debug("Updated max_model_len to %d", max_model_len)

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate the GPU KV cache described by `kv_cache_config`.

        The KV transfer connector is initialized first. It needs
        `kv_cache_config`, and it must run before `initialize_kv_cache`
        injects KV cache groups unrelated to the connector (e.g. KV cache
        sharing layers). With sleep mode enabled, the allocation happens in
        the ``"kv_cache"`` memory pool.
        """
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # No-op context when sleep mode is disabled.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up the model, capture CUDA graphs, then warm the sampler.

        The steps run in this order:

        1. Under ``CompilationMode.VLLM_COMPILE``, run dummy batches, largest
           first, for every extra size that needs compiling. These are the
           compile sizes not already covered by CUDA-graph capture, plus the
           end of every compile range that no known size falls into.
        2. Warm up kernels, then capture CUDA graphs unless eager mode is
           enforced.
        3. If the KV cache size came from memory profiling, log a suggested
           explicit KV cache size.
        4. On the last PP rank, run a dummy forward plus sampler/pooler pass
           at the maximum request count. This happens after graph capture so
           `torch.cuda.empty_cache` does not clear those buffers.
        5. Reset the random seed.
        """
        compilation_config = self.vllm_config.compilation_config
        runner = self.model_runner
        warmup_sizes = []

        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Sizes outside the CUDA-graph capture sizes that users still
            # want compiled (e.g. max-num-batched tokens in chunked prefill).
            compile_sizes = compilation_config.compile_sizes
            warmup_sizes = [] if compile_sizes is None else compile_sizes.copy()
            cg_capture_sizes: list[int] = []

            if compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                capture_sizes = compilation_config.cudagraph_capture_sizes
                cg_capture_sizes = capture_sizes if capture_sizes is not None else []
                warmup_sizes = [s for s in warmup_sizes if s not in cg_capture_sizes]

            # Make sure every compile range gets compiled: if no known size
            # lies inside a range, also warm up at the range's end.
            compile_ranges = compilation_config.get_compile_ranges()
            covered_sizes = set(cg_capture_sizes)
            covered_sizes.update(s for s in warmup_sizes if isinstance(s, int))
            for compile_range in compile_ranges:
                if not any(s in compile_range for s in covered_sizes):
                    warmup_sizes.append(compile_range.end)

        # EPLB is skipped so that dummy runs record no metrics.
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        runner.maybe_remove_all_loras(runner.lora_config)

        # Warm up and tune kernels before CUDA graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = (
            0 if self.model_config.enforce_eager else runner.capture_model()
        )

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # memory_profiling ignores CUDA graph memory and may leave GPU
            # memory unused, so suggest an explicit KV cache size that users
            # can pass for fine-grained control. Profiling was empirically
            # seen to slightly underestimate usage, so keep a 150 MiB margin.
            safety_margin_bytes = 150 * (1 << 20)
            non_kv_cache_memory = (
                runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - safety_margin_bytes
            )
            to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - safety_margin_bytes
            )
            suggestion = (
                f"Free memory on device "
                f"({format_gib(self.init_snapshot.free_memory)}/"
                f"{format_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{format_gib(self.requested_memory)} GiB). "
                f"Actual usage is {format_gib(runner.model_memory_usage)} "
                f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {format_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{to_requested_limit}` "
                f"({format_gib(to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{to_gpu_limit}` "
                f"({format_gib(to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
            )
            logger.debug(suggestion)

        # Warm up the sampler/pooler and preallocate max-shape buffers to
        # avoid memory fragmentation. NOTE: done after `capture_model` on
        # purpose, so `torch.cuda.empty_cache` does not clear these buffers.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            # EPLB skipped: dummy runs should not record metrics.
            hidden_states, last_hidden_states = runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if runner.is_pooling_model:
                runner._dummy_pooler_run(hidden_states)
            else:
                runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed so that model init and profiling do not leave the
        # random state perturbed.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        """Ask the model runner to reset its multimodal cache."""
        runner = self.model_runner
        runner.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        """Ask the model runner to reset its encoder cache."""
        runner = self.model_runner
        runner.reset_encoder_cache()

    def get_model(self) -> nn.Module:
        """Return the model module held by the model runner."""
        runner = self.model_runner
        return runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the loaded model supports, as reported by the runner."""
        runner = self.model_runner
        return runner.get_supported_tasks()

    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Return the encoder timing statistics collected by the model runner."""
        runner = self.model_runner
        return runner.get_encoder_timing_stats()

    def annotate_profile(self, scheduler_output):
        """Advance the profiler and return a context manager labelling this step.

        The label records how many context requests (requests that have not
        yet generated any tokens) and generation requests this iteration
        contains, with their token counts, so iterations are easy to tell
        apart in traces. If no profiler is configured, nothing is stepped and
        a no-op ``nullcontext()`` is returned.
        """
        if not self.profiler:
            return nullcontext()

        self.profiler.step()

        details = compute_iteration_details(scheduler_output)
        annotation = (
            f"execute_context_{details.num_ctx_requests}"
            f"({details.num_ctx_tokens})"
            f"_generation_{details.num_generation_requests}"
            f"({details.num_generation_tokens})"
        )
        return self.profiler.annotate_context_manager(annotation)

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Sample tokens via the model runner.

        `grammar_output` (possibly ``None``) is passed through unchanged.
        Runs under ``torch.inference_mode()``.
        """
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)

    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one model step and handle the pipeline-parallel hand-off.

        1. When there are tokens to run, every PP rank except the first
           first receives intermediate tensors from the previous rank.
        2. If the runner returns a final output (or ``None``), that value is
           returned directly.
        3. Otherwise the intermediate tensors are sent to the next PP rank,
           and ``None`` is returned.

        With PP > 1 and sequence parallelism enabled, whether the
        ``"residual"`` tensor is all-gathered is decided per batch from the
        padded token count (V1 model runner only).
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0
        intermediate_tensors = None
        all_gather_tensors = {}
        parallel_config = self.vllm_config.parallel_config
        compilation_config = self.vllm_config.compilation_config

        sp_across_pp = (
            parallel_config.pipeline_parallel_size > 1
            and compilation_config.pass_config.enable_sp
        )
        if sp_across_pp and forward_pass:
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            per_req_tokens = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(per_req_tokens),
                    num_scheduled_tokens_np=per_req_tokens,
                    max_num_scheduled_tokens=per_req_tokens.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            residual_scattered = is_residual_scattered_for_sp(
                self.vllm_config, batch_desc.num_tokens
            )
            all_gather_tensors = {"residual": not residual_scattered}

        if forward_pass and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            assert received is not None
            intermediate_tensors = IntermediateTensors(received)

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if isinstance(
                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
            ):
                return output

        # Anything else is an intermediate result for the next PP stage.
        assert isinstance(output, IntermediateTensors)
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )
        return None

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the model runner's `take_draft_token_ids()` result (may be None)."""
        runner = self.model_runner
        return runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the configured profiler.

        Raises:
            RuntimeError: if no profiler was configured through
                ``--profiler-config``.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
        if not is_start:
            profiler.stop()
        else:
            profiler.start()

    def execute_dummy_batch(self) -> None:
        """Run a single-token, uniform-decode dummy pass on the model runner."""
        dummy_num_tokens = 1
        self.model_runner._dummy_run(dummy_num_tokens, uniform_decode=True)
689

690
691
692
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model runner and return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

693
694
695
    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model runner to remove adapter ``lora_id``; return its result."""
        removed = self.model_runner.remove_lora(lora_id)
        return removed

696
    def list_loras(self) -> set[int]:
        """Return the adapter ids reported by ``model_runner.list_loras()``."""
        lora_ids = self.model_runner.list_loras()
        return lora_ids

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a pin request for adapter ``lora_id`` to the model runner."""
        pinned = self.model_runner.pin_lora(lora_id)
        return pinned

702
703
704
705
    def check_health(self) -> None:
        """Health probe: a worker that is still running is considered healthy."""
        return None

706
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts off the departing EP ranks before scaling down.

        Builds a rank mapping over the old EP world in which every rank below
        ``new_ep_size`` keeps its index and every other rank maps to ``-1``.
        The mapping is passed to ``eplb_state.rearrange(execute_shuffle=True)``.
        EP rank 0 logs the start and the completion.

        Args:
            old_ep_size: EP world size before the reconfiguration.
            new_ep_size: EP world size after the reconfiguration.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        rank_mapping: dict[int, int] = {}
        for rank in range(old_ep_size):
            # Surviving ranks keep their index; ranks being removed map to -1.
            rank_mapping[rank] = rank if rank < new_ep_size else -1

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        # Block until all queued CUDA work on the current device has finished.
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts onto the enlarged EP world after scaling up.

        Every pre-existing EP rank keeps its index in the rank mapping; ranks
        added by the scale-up have no entry. The given ``global_expert_loads``
        are passed to ``eplb_state.rearrange(execute_shuffle=True)``. EP rank 0
        logs the start and the completion.

        Args:
            old_ep_size: EP world size before the reconfiguration.
            new_ep_size: EP world size after the reconfiguration (not read in
                this body).
            global_expert_loads: Loads to feed into the rearrangement.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        # Identity mapping over the old ranks only.
        rank_mapping = dict(zip(range(old_ep_size), range(old_ep_size)))

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=rank_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Copy the data-parallel settings of ``reconfig_request`` into the config.

        The DP size, master IP and master port are always overwritten. The DP
        rank and local DP rank are only overwritten when the request does not
        carry the ``ReconfigureRankType.KEEP_CURRENT_RANK`` sentinel.
        """
        cfg = self.vllm_config.parallel_config
        keep_rank = ReconfigureRankType.KEEP_CURRENT_RANK

        cfg.data_parallel_size = reconfig_request.new_data_parallel_size

        new_rank = reconfig_request.new_data_parallel_rank
        if new_rank != keep_rank:
            cfg.data_parallel_rank = new_rank

        new_local_rank = reconfig_request.new_data_parallel_rank_local
        if new_local_rank != keep_rank:
            cfg.data_parallel_rank_local = new_local_rank

        cfg.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip
        cfg.data_parallel_master_port = reconfig_request.new_data_parallel_master_port
773

774
775
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """Re-target every MoE layer at the new expert-parallel world size.

        Updates the expert counts and ``FusedMoEParallelConfig`` of each
        ``FusedMoE`` / ``SharedFusedMoE`` layer in the main model. The drafter
        model is updated too when it is a mixture-of-experts model.

        The method then recomputes ``eplb_config.num_redundant_experts``,
        prepares the communication buffers and pushes the new physical-expert
        counts to the main model.

        Args:
            old_ep_size: EP world size before the reconfiguration.
            new_ep_size: EP world size after the reconfiguration.

        Returns:
            The global expert loads returned by a non-shuffling EPLB
            ``rearrange`` when ``new_ep_size >= old_ep_size``, or ``None`` when
            scaling down.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config
        moe_class_names = ("FusedMoE", "SharedFusedMoE")

        def collect_moe_layers(root: torch.nn.Module) -> list[FusedMoE]:
            # Selected by exact class name; subclasses named otherwise are skipped.
            return [m for m in root.modules() if type(m).__name__ in moe_class_names]

        def retarget_moe_layers(layers: list[FusedMoE], experts_per_rank: int) -> None:
            assert all(
                layer.moe_config.num_local_experts == experts_per_rank
                for layer in layers
            ), "All MoE modules must have the same number of experts"
            for layer in layers:
                layer.moe_config.num_experts = experts_per_rank * new_ep_size
                layer.global_num_experts = layer.moe_config.num_experts
                layer.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                layer.moe_config.moe_parallel_config = layer.moe_parallel_config

        main_model = self.model_runner.model
        main_layers = collect_moe_layers(main_model)
        num_local_experts = main_layers[0].moe_config.num_local_experts
        retarget_moe_layers(main_layers, num_local_experts)

        drafter = getattr(self.model_runner, "drafter", None)
        drafter_model = getattr(drafter, "model", None)
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_layers = collect_moe_layers(drafter_model)
            # The drafter must use the same per-rank expert count as the model.
            assert (
                drafter_layers[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            retarget_moe_layers(drafter_layers, num_local_experts)

        if new_ep_size < old_ep_size:
            # Scale-down: presumably the EPLB state was already resharded by
            # _eplb_before_scale_down, so the new sizes are read from it.
            num_local_physical_experts = num_local_experts
            eplb_state = self.model_runner.eplb_state
            assert eplb_state is not None
            new_physical_experts = (
                eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            num_logical_experts = (
                eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - num_logical_experts
            )
            global_expert_loads = None
        else:
            # Scale-up (or unchanged size): every EP rank adopts EP rank 0's
            # per-rank expert count.
            local_count = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                local_count,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(local_count.item())
            new_physical_experts = num_local_physical_experts * new_ep_size
            eplb_state = self.model_runner.eplb_state
            assert eplb_state is not None
            # With execute_shuffle=False this is expected to only compute the
            # global expert loads (no weight shuffle) -- confirm in EplbState.
            global_expert_loads = cast(
                list[torch.Tensor], eplb_state.rearrange(execute_shuffle=False)
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )

        prepare_communication_buffer_for_model(main_model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        main_model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads
877
878

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed environment for a new DP size.

        Steps:

        1. When scaling down, reshard experts off the departing ranks.
        2. Destroy the current distributed environment.
        3. Return early if this rank is being shut down.
        4. Otherwise, apply the new parallel config, re-create the
           process groups and reconfigure the MoE layers.
        5. When scaling up, reshard experts onto the new ranks.
        """
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        ep_group = get_ep_group()
        old_ep_size, old_ep_rank = ep_group.world_size, ep_group.rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )

        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        shutting_down = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )
        if shutting_down:
            # Only ranks that fall outside the new EP world may be shut down.
            assert old_ep_rank >= new_ep_size
            return

        self._reconfigure_parallel_config(reconfig_request)
        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
        if new_ep_size > old_ep_size:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
922

923
924
925
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model's weights as sharded state under ``path``.

        Delegates to ``ShardedStateLoader.save_model``. ``pattern`` and
        ``max_size`` are forwarded unchanged.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

938
939
940
941
942
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the model via the model runner's tensorizer support."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)
945

946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
    def init_weight_transfer_engine(self, init_info: dict) -> None:
        """Set up the configured weight-transfer backend.

        The engine parses ``init_info`` into its backend-specific typed form
        and passes the result to ``init_transfer_engine``. For the NCCL
        backend, this creates a process group with the trainer.

        Args:
            init_info: Backend-specific initialization parameters.

        Raises:
            RuntimeError: if no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )
        engine.init_transfer_engine(engine.parse_init_info(init_info))

    def update_weights(self, update_info: dict) -> None:
        """Apply a batched weight update received from the trainer.

        ``update_info`` is parsed by the weight transfer engine into its
        backend-specific typed form. How the weights are applied depends on
        that form:

        * Checkpoint format: weights are routed through ``model.load_weights``
          between ``initialize_layerwise_reload`` and
          ``finalize_layerwise_reload``, under ``torch.device(self.device)``.
        * Any other format: weights are treated as already being in kernel
          format and are copied in place into the parameter with the same name.

        Args:
            update_info: Backend-specific update parameters.

        Raises:
            RuntimeError: if no weight transfer engine is configured.
        """
        engine = self.weight_transfer_engine
        if engine is None:
            raise RuntimeError(
                "Weight transfer not configured. "
                "Please set weight_transfer_config to enable weight transfer."
            )

        # Parse dict into backend-specific typed dataclass
        typed_update_info = engine.parse_update_info(update_info)

        model = self.model_runner.model

        if typed_update_info.is_checkpoint_format:
            from vllm.model_executor.model_loader.reload import (
                finalize_layerwise_reload,
                initialize_layerwise_reload,
            )

            # Use layerwise reload pattern for checkpoint format weights
            with torch.device(self.device):
                initialize_layerwise_reload(model)
                engine.receive_weights(
                    typed_update_info,
                    load_weights=model.load_weights,
                )
                finalize_layerwise_reload(model, self.model_config)
            return

        # Weights are already in kernel format, copy directly
        def load_weights_direct(
            weights: list[tuple[str, torch.Tensor]],
        ) -> None:
            # In-place copy_ into a leaf parameter with requires_grad=True
            # raises under autograd, so the copies run without grad tracking.
            with torch.no_grad():
                for name, weight in weights:
                    model.get_parameter(name).copy_(weight)

        engine.receive_weights(
            typed_update_info,
            load_weights=load_weights_direct,
        )

1009
    def shutdown(self) -> None:
        """Release the worker's external resources.

        Shuts down, in order, the KV-transfer machinery, the profiler (if one
        is set) and the weight transfer engine (if one is set).

        The profiler and engine attributes are read with ``getattr``. This
        method therefore does not raise AttributeError if those attributes were
        never assigned on this instance.
        """
        # ensure_kv_transfer_shutdown can be None during interpreter shutdown.
        if ensure_kv_transfer_shutdown is not None:
            ensure_kv_transfer_shutdown()

        profiler = getattr(self, "profiler", None)
        if profiler is not None:
            profiler.shutdown()

        weight_transfer_engine = getattr(self, "weight_transfer_engine", None)
        if weight_transfer_engine is not None:
            weight_transfer_engine.shutdown()

1019
1020

def init_worker_distributed_environment(
1021
    vllm_config: VllmConfig,
1022
    rank: int,
1023
    distributed_init_method: str | None = None,
1024
    local_rank: int = -1,
1025
    backend: str = "nccl",
1026
1027
) -> None:
    """Initialize the distributed environment."""
1028
    attention_config = vllm_config.attention_config
1029
    parallel_config = vllm_config.parallel_config
1030
1031
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

1032
    init_batch_invariance(attention_config.backend)
1033
1034
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

1035
    init_method = distributed_init_method or "env://"
1036
    init_distributed_environment(
1037
        parallel_config.world_size, rank, init_method, local_rank, backend
1038
    )
1039

1040
1041
1042
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
1043
        parallel_config.prefill_context_parallel_size,
1044
1045
        parallel_config.decode_context_parallel_size,
    )
1046
1047
1048
1049

    # Init ec connector here before KV caches caches init
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)