"vllm/vscode:/vscode.git/clone" did not exist on "9c749713f6990a9f9d12e526d9bfc2669dfa8ee6"
gpu_worker.py 39.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any, cast
10

11
import numpy as np
12
13
import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
18
from vllm.config.compilation import CompilationMode
19
20
21
22
23
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
24
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
25
26
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
27
    ensure_kv_transfer_shutdown,
28
29
30
    get_kv_transfer_group,
    has_kv_transfer_group,
)
31
from vllm.distributed.parallel_state import (
32
    get_pcp_group,
33
34
35
    get_pp_group,
    get_tp_group,
)
36
from vllm.logger import init_logger
37
from vllm.lora.request import LoRARequest
38
from vllm.model_executor.models.interfaces import is_mixture_of_experts
39
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
40
from vllm.platforms import current_platform
41
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
42
from vllm.sequence import IntermediateTensors
43
from vllm.tasks import SupportedTask
44
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
45
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
46
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
47
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
48
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
49
50
51
52
53
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
54
from vllm.v1.utils import compute_iteration_details, report_usage_stats
55
from vllm.v1.worker.utils import is_residual_scattered_for_sp
56
from vllm.v1.worker.worker_base import WorkerBase
57
from vllm.v1.worker.workspace import init_workspace_manager
58

59
60
from .utils import request_memory

61
62
63
if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

# Module-level logger for the GPU worker.
logger = init_logger(__name__)


68
class Worker(WorkerBase):
69
70
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create a GPU worker for ``vllm_config``.

        Only configuration-level setup happens here: the float32 matmul
        precision is applied, the sleep-mode buffer store is created and the
        optional profiler is built. Device selection and distributed setup are
        done later in ``init_device``.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Float32 matmul precision comes from a vLLM environment variable.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        # Model buffers copied to CPU by a level-2 sleep, restored in wake_up.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Optional Torch/CUDA profiler, chosen by profiler_config.profiler.
        self.profiler: Any | None = None
        prof_cfg = vllm_config.profiler_config
        prof_kind = prof_cfg.profiler
        if prof_kind == "torch":
            self.profiler = TorchProfilerWrapper(
                prof_cfg,
                worker_name=f"{vllm_config.instance_id}-rank-{self.rank}",
                local_rank=self.local_rank,
                activities=["CPU", "CUDA"],
            )
        elif prof_kind == "cuda":
            self.profiler = CudaProfilerWrapper(prof_cfg)
        else:
            self.profiler = None

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
        if self.use_v2_model_runner:
            logger.info_once("Using V2 Model Runner", scope="global")
    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep and log how much memory was freed.

        Args:
            level: With ``1``, the allocator is asked to offload the
                ``"weights"`` tag. With any other value nothing is offloaded.
                With ``2``, the model's buffers are additionally copied to CPU
                first so that ``wake_up`` can restore them.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        if level == 2:
            # Level 2 offloads nothing, so keep CPU copies of the buffers.
            saved: dict[str, torch.Tensor] = {}
            for buf_name, buf in self.model_runner.model.named_buffers():
                saved[buf_name] = buf.cpu().clone()
            self._sleep_saved_buffers = saved

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(freed_bytes),
            format_gib(total - free_after),
        )
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up CuMem allocator memory and restore state lost while asleep.

        Args:
            tags: Allocator tags to wake up (e.g. ``"weights"``,
                ``"kv_cache"``). ``None`` wakes up every tag.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore buffers saved by a level-2 sleep only once the weights are
        # awake. The buffers are created together with the model weights
        # (presumably inside the "weights" pool when sleep mode is on), so
        # copying into them after waking only e.g. "kv_cache" would write to
        # memory that is still released. The snapshot is kept until a later
        # wake_up covers the weights.
        weights_awake = tags is None or "weights" in tags
        if weights_awake and self._sleep_saved_buffers:
            saved = self._sleep_saved_buffers
            for name, buffer in self.model_runner.model.named_buffers():
                if name in saved:
                    buffer.data.copy_(saved[name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up, the internal state of the
        # cache (especially the FP8 scaling factors) must be reset.
        if (
            (tags is None or "kv_cache" in tags)
            and self.cache_config.cache_dtype.startswith("fp8")
            and hasattr(self.model_runner, "init_fp8_kv_scales")
        ):
            self.model_runner.init_fp8_kv_scales()
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context manager that allocates into the ``tag`` memory pool.

        Without sleep mode this is a no-op ``nullcontext``. With sleep mode it
        is the CuMem allocator's pool for ``tag``. For ``"weights"``, the
        allocator must not hold any memory yet; this enforces one sleep-mode
        instance per process.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU and CPU KV-cache block counts on the cache config."""
        cache_cfg = self.cache_config
        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks
    def init_device(self):
        """Bind this worker to its CUDA device and construct the model runner.

        The steps run in this order:

        1. Shift ``local_rank`` by the data-parallel rank. This only happens
           when the executor is neither Ray nor an external launcher, the DP
           backend is not Ray, and ``nnodes_within_dp == 1``.
        2. Select the device.
        3. Initialize the distributed environment.
        4. Seed the RNGs.
        5. Take the initial memory snapshot used later by memory profiling.
        6. Set up the workspace manager.
        7. Build the V1 or V2 GPU model runner.

        Raises:
            RuntimeError: if ``device_config.device_type`` is not ``"cuda"``.
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        pc = self.parallel_config
        adjust_for_dp = (
            pc.distributed_executor_backend not in ("ray", "external_launcher")
            and pc.data_parallel_backend != "ray"
            and pc.nnodes_within_dp == 1
        )
        if adjust_for_dp:
            # Prefer the node-local DP rank; fall back to the global DP index.
            dp_rank = pc.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = pc.data_parallel_index
            # local_rank = DP_LOCAL_RANK * (PP * TP) + TP_LOCAL_RANK
            self.local_rank += dp_rank * (
                pc.pipeline_parallel_size * pc.tensor_parallel_size
            )
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )
            num_visible = torch.cuda.device_count() if torch.cuda.is_available() else 0
            assert pc.local_world_size <= num_visible, (
                f"local_world_size ({pc.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({num_visible})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Bring up the distributed environment before the memory snapshot so
        # NCCL buffers are already allocated when free memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )
        set_random_seed(self.model_config.seed)

        # Collect garbage and release cached blocks, then snapshot memory.
        gc.collect()
        torch.cuda.empty_cache()
        snapshot = MemorySnapshot(device=self.device)
        self.init_snapshot = snapshot
        self.requested_memory = request_memory(snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Two micro-batches when enable_dbo is set, otherwise one.
        init_workspace_manager(
            self.device, 2 if self.vllm_config.parallel_config.enable_dbo else 1
        )

        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model weights through the model runner.

        Weights are allocated inside the sleep-mode ``"weights"`` memory pool
        when sleep mode is enabled. The active vLLM config is set for the
        duration of the load. The elastic-EP scale-up flag is read from
        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` and forwarded to the runner.
        """
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        # Enter BOTH context managers. Joining them with `and` would only
        # enter the second one (context-manager objects are truthy), so the
        # weights would never land in the sleep-mode "weights" pool.
        with (
            self._maybe_get_memory_pool_context(tag="weights"),
            set_current_vllm_config(self.vllm_config),
        ):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)
    def reload_weights(self, *args, **kwargs) -> None:
        """Delegate to the model runner's ``reload_weights`` with all arguments."""
        reload = self.model_runner.reload_weights
        reload(*args, **kwargs)
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return the number of bytes the KV cache may use without OOMing.

        If ``cache_config.kv_cache_memory_bytes`` is set (truthy), one
        ``profile_run`` is still executed, because it compiles the model for
        ``max_num_batched_tokens``. Memory profiling is skipped and that value
        is returned as-is.

        Otherwise a dummy forward pass is profiled. The result is
        ``self.requested_memory`` (computed in ``init_device``) minus the
        measured non-KV-cache memory. This path also stores
        ``non_torch_memory``, ``peak_activation_memory`` and
        ``available_kv_cache_memory_bytes`` on the worker.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        snap = self.init_snapshot
        manual_bytes = self.cache_config.kv_cache_memory_bytes
        if manual_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()
            msg = (
                f"Initial free memory {format_gib(snap.free_memory)} "
                f"GiB, reserved {format_gib(manual_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return manual_bytes

        runner = self.model_runner
        # Profile a forward pass with dummy inputs to measure model memory use.
        with memory_profiling(
            snap,
            weights_memory=int(runner.model_memory_usage),
        ) as result:
            runner.profile_run()

        self.non_torch_memory = result.non_torch_increase
        self.peak_activation_memory = result.torch_peak_increase

        free_after = result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert snap.free_memory >= free_after, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(snap.free_memory)} GiB, "
            f"current free memory {format_gib(free_after)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )

        budget = self.requested_memory
        self.available_kv_cache_memory_bytes = budget - result.non_kv_cache_memory

        # Free memory at startup that lies outside the requested budget.
        outside_budget = snap.free_memory - budget
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(snap.free_memory),
            self.cache_config.gpu_memory_utilization,
            format_gib(budget),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_after),
            format_gib(free_after - outside_budget),
        )
        logger.debug(result)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        return int(self.available_kv_cache_memory_bytes)
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return ``{tp_rank: metadata}`` for this worker's KV connector.

        Returns ``None`` in two cases: no KV transfer group exists, or the
        connector has no handshake metadata to exchange across workers.
        """
        metadata = (
            get_kv_transfer_group().get_handshake_metadata()
            if has_kv_transfer_group()
            else None
        )
        if metadata is None:
            return None
        return {get_tp_group().rank_in_group: metadata}
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec mapping reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()
    def update_max_model_len(self, max_model_len: int) -> None:
        """Adopt the engine's auto-fitted ``max_model_len``.

        Used when ``max_model_len=-1`` lets the engine pick the largest
        context length that fits in GPU memory. The new value is written to
        the model config and, if a model runner exists, forwarded to it.
        """
        self.model_config.max_model_len = max_model_len
        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)
        logger.debug("Updated max_model_len to %d", max_model_len)
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate the GPU KV cache described by ``kv_cache_config``."""
        # The KV transfer connector needs `kv_cache_config`, and must be set
        # up before `initialize_kv_cache`. That call injects KV cache groups
        # unrelated to the connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # With sleep mode, allocate inside the "kv_cache" pool so the cache is
        # managed under that tag by sleep()/wake_up(); otherwise a no-op.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
    def compile_or_warm_up_model(self) -> None:
        """Compile, warm up and (unless eager) CUDA-graph-capture the model.

        The steps run in this order:

        1. Dummy forward passes for the compile warmup sizes, largest first.
        2. Kernel warmup.
        3. CUDA graph capture, unless ``enforce_eager`` is set.
        4. Debug-log suggested ``--kv-cache-memory`` values when the KV cache
           size came from memory profiling.
        5. Sampler/pooler warmup on the last PP rank.
        6. Re-seed the RNGs.
        """
        # EPLB is skipped so these dummy runs record no metrics.
        for size in sorted(self._collect_compile_warmup_sizes(), reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = (
            0 if self.model_config.enforce_eager else self.model_runner.capture_model()
        )

        profiled_kv_cache = self.cache_config.kv_cache_memory_bytes is None
        if profiled_kv_cache and hasattr(self, "peak_activation_memory"):
            self._log_kv_cache_memory_suggestion(cuda_graph_memory_bytes)

        # Warm up the sampler (or pooler) at max shape so its logits/sampling
        # buffers are preallocated, avoiding fragmentation. This runs after
        # `capture_model` on purpose, so `torch.cuda.empty_cache` cannot clear
        # those buffers.
        if get_pp_group().is_last_rank:
            num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            # EPLB is skipped so this dummy run records no metrics.
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Re-seed so model initialization and profiling leave the RNG state
        # unaffected.
        set_random_seed(self.model_config.seed)

    def _collect_compile_warmup_sizes(self) -> list:
        """Return the batch sizes to compile/warm up before graph capture.

        The result is empty unless the compilation mode is
        ``CompilationMode.VLLM_COMPILE``. It is built as follows:

        - Start from a copy of ``compile_sizes``.
        - Drop sizes that CUDA graph capture already covers.
        - For every compile range that contains none of the known sizes,
          append the range's ``end``.
        """
        comp_cfg = self.vllm_config.compilation_config
        if comp_cfg.mode != CompilationMode.VLLM_COMPILE:
            return []

        # Sizes outside the cudagraph capture set that users still want
        # compiled, e.g. the max-num-batched token size in chunked prefill.
        requested = comp_cfg.compile_sizes
        sizes = [] if requested is None else requested.copy()

        capture_sizes: list[int] = []
        if comp_cfg.cudagraph_mode != CUDAGraphMode.NONE:
            configured = comp_cfg.cudagraph_capture_sizes
            capture_sizes = [] if configured is None else configured
            sizes = [s for s in sizes if s not in capture_sizes]

        ranges = comp_cfg.get_compile_ranges()
        # Every compile range must be compiled: if no known size falls inside
        # a range, also warm up that range's end.
        covered = set(capture_sizes) | {s for s in sizes if isinstance(s, int)}
        for rng in ranges:
            if not any(s in rng for s in covered):
                sizes.append(rng.end)
        return sizes

    def _log_kv_cache_memory_suggestion(self, cuda_graph_memory_bytes: int) -> None:
        """Debug-log suggested ``--kv-cache-memory`` values after profiling.

        ``memory_profiling`` (see ``determine_available_memory``) ignores CUDA
        graph memory and may leave GPU memory unused. The logged values let
        users pin the KV cache size explicitly instead.
        """
        # Profiling was empirically observed to slightly underestimate memory
        # use, so keep a 150 MiB safety margin to avoid OOM.
        margin = 150 * (1 << 20)
        weights = self.model_runner.model_memory_usage
        non_kv_cache_memory = (
            weights
            + self.peak_activation_memory
            + self.non_torch_memory
            + cuda_graph_memory_bytes
        )
        snap = self.init_snapshot
        to_gpu_limit = snap.free_memory - non_kv_cache_memory - margin
        to_requested_limit = int(self.requested_memory) - non_kv_cache_memory - margin

        msg = (
            f"Free memory on device "
            f"({format_gib(snap.free_memory)}/"
            f"{format_gib(snap.total_memory)} GiB) on startup. "
            f"Desired GPU memory utilization is "
            f"({self.cache_config.gpu_memory_utilization}, "
            f"{format_gib(self.requested_memory)} GiB). "
            f"Actual usage is {format_gib(weights)} "
            f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
            f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
            f"for non-torch memory, and {format_gib(cuda_graph_memory_bytes)} "
            f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
            f"config with `--kv-cache-memory="
            f"{to_requested_limit}` "
            f"({format_gib(to_requested_limit)} GiB) to fit "
            f"into requested memory, or `--kv-cache-memory="
            f"{to_gpu_limit}` "
            f"({format_gib(to_gpu_limit)} GiB) to fully "
            f"utilize gpu memory. Current kv cache memory in use is "
            f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
        )
        logger.debug(msg)
    def reset_mm_cache(self) -> None:
        """Delegate to the model runner's ``reset_mm_cache``."""
        runner = self.model_runner
        runner.reset_mm_cache()
    def reset_encoder_cache(self) -> None:
        """Delegate to the model runner's ``reset_encoder_cache``."""
        runner = self.model_runner
        runner.reset_encoder_cache()
    def get_model(self) -> nn.Module:
        """Return the ``nn.Module`` provided by the model runner."""
        runner = self.model_runner
        return runner.get_model()
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        runner = self.model_runner
        return runner.get_supported_tasks()
    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Return the encoder timing stats reported by the model runner."""
        runner = self.model_runner
        return runner.get_encoder_timing_stats()
    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels this step in the profiler trace.

        Without a profiler this is a ``nullcontext``. Otherwise the profiler
        is stepped, and the step is annotated with the number of context and
        generation requests (and their tokens). This makes iterations easy to
        tell apart in the trace. A context request is a request that has not
        yet generated any tokens.
        """
        if not self.profiler:
            return nullcontext()

        self.profiler.step()

        details = compute_iteration_details(scheduler_output)
        label = (
            f"execute_context_{details.num_ctx_requests}"
            f"({details.num_ctx_tokens})"
            f"_generation_{details.num_generation_requests}"
            f"({details.num_generation_tokens})"
        )
        return self.profiler.annotate_context_manager(label)
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Sample tokens for the current step; ``grammar_output`` is forwarded as-is."""
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one scheduler step on this worker, including PP communication.

        If there are scheduled tokens, non-first pipeline-parallel ranks first
        receive intermediate tensors from the previous stage. Final output
        (or ``None``) from the model runner is returned unchanged. Otherwise
        the runner returned ``IntermediateTensors``; they are sent to the next
        PP stage and ``None`` is returned.
        """
        total_tokens = scheduler_output.total_num_scheduled_tokens
        has_work = total_tokens > 0

        # With PP + sequence parallelism, decide whether the residual has to
        # be all-gathered over the TP group when tensors cross PP stages.
        all_gather_tensors: dict[str, bool] = {}
        if (
            self.vllm_config.parallel_config.pipeline_parallel_size > 1
            and self.vllm_config.compilation_config.pass_config.enable_sp
            and has_work
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            per_req_tokens = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=total_tokens,
                    num_reqs=len(per_req_tokens),
                    num_scheduled_tokens_np=per_req_tokens,
                    max_num_scheduled_tokens=per_req_tokens.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            residual_scattered = is_residual_scattered_for_sp(
                self.vllm_config, batch_desc.num_tokens
            )
            all_gather_tensors = {"residual": not residual_scattered}

        intermediate_tensors = None
        if has_work and not get_pp_group().is_first_rank:
            # Receive the previous PP stage's activations.
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            assert received is not None
            intermediate_tensors = IntermediateTensors(received)

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            # Final results (or nothing) go straight back to the caller.
            if isinstance(
                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
            ):
                return output

        # Otherwise this is a non-last PP rank: forward activations downstream.
        assert isinstance(output, IntermediateTensors)
        assert (
            self.vllm_config.parallel_config.distributed_executor_backend
            != "external_launcher"
            and not get_pp_group().is_last_rank
        )
        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )
        return None
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Delegate to the model runner's ``take_draft_token_ids``."""
        runner = self.model_runner
        return runner.take_draft_token_ids()
    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the configured profiler.

        Raises:
            RuntimeError: if no profiler was configured via profiler_config.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
        action = profiler.start if is_start else profiler.stop
        action()
    def execute_dummy_batch(self) -> None:
        """Run a one-token uniform-decode dummy batch through the model runner."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Pass ``lora_request`` to the model runner and return its bool result."""
        runner = self.model_runner
        return runner.add_lora(lora_request)
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter ``lora_id`` from the model runner.

        Returns the boolean result of ``self.model_runner.remove_lora``.
        """
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the set of LoRA ids reported by the model runner."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter ``lora_id`` in the model runner.

        Returns the boolean result of ``self.model_runner.pin_lora``.
        """
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """Health check hook; always succeeds.

        The worker is considered healthy as long as it is running, so this
        returns None without performing any checks.
        """
        return

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts before the expert-parallel (EP) group shrinks.

        Builds a rank mapping over the old EP ranks in which ranks below
        ``new_ep_size`` keep their index and ranks at or above it map to -1
        (presumably "rank removed"; confirm against ``EplbState.rearrange``),
        then runs ``eplb_state.rearrange`` with the shuffle executed, and
        synchronizes CUDA before returning.

        Args:
            old_ep_size: EP world size before scaling down.
            new_ep_size: EP world size after scaling down.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        # Wait for the expert shuffle to finish on the device before the
        # caller tears down the distributed environment.
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts after the expert-parallel (EP) group has grown.

        Old EP ranks keep their index (identity rank mapping over
        ``range(old_ep_size)``), and ``eplb_state.rearrange`` is run with the
        shuffle executed using the supplied ``global_expert_loads``.

        Args:
            old_ep_size: EP world size before scaling up.
            new_ep_size: EP world size after scaling up (not read here).
            global_expert_loads: Expert loads gathered before the
                reconfiguration (as returned by ``_reconfigure_moe``).
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=rank_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Update the data-parallel fields of the parallel config in place.

        ``data_parallel_size``, ``data_parallel_master_ip`` and
        ``data_parallel_master_port`` are always overwritten from
        ``reconfig_request``. ``data_parallel_rank`` and
        ``data_parallel_rank_local`` are only overwritten when the request
        does not carry ``ReconfigureRankType.KEEP_CURRENT_RANK``.
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """
        Reconfigure MoE modules for a new expert-parallel (EP) size.

        Every ``FusedMoE``/``SharedFusedMoE`` module of the model (and of the
        drafter model, when it is a mixture-of-experts model) gets its global
        expert count set to ``num_local_experts * new_ep_size`` and its
        ``FusedMoEParallelConfig`` rebuilt from the current TP/PCP/DP groups.
        The number of redundant EPLB experts is then recomputed, the
        communication buffers are prepared again, and the model's physical
        expert metadata is updated.

        Args:
            old_ep_size: EP world size before the reconfiguration.
            new_ep_size: EP world size after the reconfiguration.

        Returns:
            The global expert loads returned by
            ``eplb_state.rearrange(execute_shuffle=False)`` when
            ``new_ep_size >= old_ep_size``; ``None`` when scaling down.

        Raises:
            RuntimeError: If the model, or an MoE drafter model, contains no
                ``FusedMoE``/``SharedFusedMoE`` modules.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Collect modules whose class name is exactly FusedMoE or
            # SharedFusedMoE.
            return [
                module
                for module in model.modules()
                if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
            ]

        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
            # Resize the global expert count and rebuild the MoE parallel
            # config of each module for the new EP size.
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of experts"
            for module in moe_modules:
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config
            return moe_modules

        model_moe_modules = get_moe_modules(self.model_runner.model)
        if not model_moe_modules:
            # Fail with a clear message instead of an IndexError below.
            raise RuntimeError(
                "[Elastic EP] No FusedMoE modules found in the model; "
                "cannot reconfigure expert parallelism."
            )
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts

        update_moe_modules(model_moe_modules, num_local_experts)
        drafter_model = None
        if hasattr(self.model_runner, "drafter") and hasattr(
            self.model_runner.drafter, "model"
        ):
            drafter_model = self.model_runner.drafter.model
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            if not drafter_moe_modules:
                raise RuntimeError(
                    "[Elastic EP] No FusedMoE modules found in the MoE drafter "
                    "model; cannot reconfigure expert parallelism."
                )
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            # Scaling down: derive the redundant-expert count from the shapes
            # of the EPLB state's maps (physical slots minus logical experts).
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            global_expert_loads = None
        else:
            # Use EP rank 0's per-rank expert count on every rank (presumably
            # so newly added ranks agree with the existing ones).
            num_local_physical_experts_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts_tensor,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(num_local_physical_experts_tensor.item())
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            # Gather expert loads only; the actual shuffle happens later.
            global_expert_loads_any = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )
        prepare_communication_buffer_for_model(self.model_runner.model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild the distributed environment for a new data-parallel size.

        Steps, in order:

        1. Compute the new EP size as
           ``new_data_parallel_size * tp_world_size * pp_world_size``.
        2. If EP shrinks, reshard experts while the old process groups still
           exist.
        3. Tear down the distributed environment.
        4. If this rank is marked ``SHUTDOWN_CURRENT_RANK``, return; it must
           lie outside the new EP range.
        5. Otherwise update the parallel config and re-initialize the
           distributed environment under the current vLLM config.
        6. Reconfigure MoE modules, and if EP grew, reshard experts onto the
           enlarged group.
        """
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            # Must run before cleanup: resharding uses the old EP group.
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            # This rank is being removed; it must fall outside the new EP size.
            assert old_ep_rank >= new_ep_size
            return

        self._reconfigure_parallel_config(reconfig_request)

        # set_current_vllm_config is imported at module level.
        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model's state via ``ShardedStateLoader.save_model``.

        Args:
            path: Destination passed through to ``save_model``.
            pattern: Optional ``pattern`` argument forwarded to ``save_model``.
            max_size: Optional ``max_size`` argument forwarded to
                ``save_model``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Delegate to ``self.model_runner.save_tensorized_model``.

        Args:
            tensorizer_config: Forwarded as the ``tensorizer_config`` keyword.
        """
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )

    def shutdown(self) -> None:
        """Shut down KV-transfer state and the profiler, if present."""
        # The module-level ensure_kv_transfer_shutdown can be None during
        # interpreter shutdown, so guard before calling it.
        if ensure_kv_transfer_shutdown is not None:
            ensure_kv_transfer_shutdown()
        if self.profiler is not None:
            self.profiler.shutdown()


def init_worker_distributed_environment(
943
    vllm_config: VllmConfig,
944
    rank: int,
945
    distributed_init_method: str | None = None,
946
    local_rank: int = -1,
947
    backend: str = "nccl",
948
949
) -> None:
    """Initialize the distributed environment."""
950
    attention_config = vllm_config.attention_config
951
    parallel_config = vllm_config.parallel_config
952
953
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

954
    init_batch_invariance(attention_config.backend)
955
956
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

957
    init_method = distributed_init_method or "env://"
958
    init_distributed_environment(
959
        parallel_config.world_size, rank, init_method, local_rank, backend
960
    )
961

962
963
964
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
965
        parallel_config.prefill_context_parallel_size,
966
967
        parallel_config.decode_context_parallel_size,
    )
968
969
970
971

    # Init ec connector here before KV caches caches init
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)