gpu_worker.py 39.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any, cast
10

11
import numpy as np
12
13
import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
18
from vllm.config.compilation import CompilationMode
19
20
21
22
23
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
24
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
25
26
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
27
    ensure_kv_transfer_shutdown,
28
29
30
    get_kv_transfer_group,
    has_kv_transfer_group,
)
31
from vllm.distributed.parallel_state import (
32
    get_pcp_group,
33
34
35
    get_pp_group,
    get_tp_group,
)
36
from vllm.logger import init_logger
37
from vllm.lora.request import LoRARequest
38
from vllm.model_executor.models.interfaces import is_mixture_of_experts
39
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
40
from vllm.platforms import current_platform
41
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
42
from vllm.sequence import IntermediateTensors
43
from vllm.tasks import SupportedTask
44
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
45
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
46
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
47
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
48
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
49
50
51
52
53
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
54
from vllm.v1.utils import compute_iteration_details, report_usage_stats
55
from vllm.v1.worker.utils import is_residual_scattered_for_sp
56
from vllm.v1.worker.worker_base import WorkerBase
57
from vllm.v1.worker.workspace import init_workspace_manager
58

59
60
from .utils import request_memory

61
62
63
# Module-level logger, named after this module via `init_logger(__name__)`.
logger = init_logger(__name__)

if TYPE_CHECKING:
64
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
65
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
66
67


68
class Worker(WorkerBase):
69
70
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create the worker and its optional profiler.

        All constructor arguments are forwarded unchanged to ``WorkerBase``.
        Afterwards the float32 matmul precision is applied from
        ``envs.VLLM_FLOAT32_MATMUL_PRECISION`` and a profiler wrapper is built
        when ``vllm_config.profiler_config.profiler`` is ``"torch"`` or
        ``"cuda"``; for any other value ``self.profiler`` stays ``None``.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Apply the float32 matmul precision requested through the vLLM env.
        torch.set_float32_matmul_precision(envs.VLLM_FLOAT32_MATMUL_PRECISION)

        # CPU copies of model buffers, filled by `sleep(level=2)` and consumed
        # by `wake_up`.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch/CUDA profiler, selected and configured via profiler_config.
        # It remains None unless one of the known profiler kinds is requested.
        self.profiler: Any | None = None
        profiler_cfg = vllm_config.profiler_config
        profiler_kind = profiler_cfg.profiler
        if profiler_kind == "torch":
            self.profiler = TorchProfilerWrapper(
                profiler_cfg,
                worker_name=f"{vllm_config.instance_id}-rank-{self.rank}",
                local_rank=self.local_rank,
                activities=["CPU", "CUDA"],
            )
        elif profiler_kind == "cuda":
            self.profiler = CudaProfilerWrapper(profiler_cfg)

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

110
    def sleep(self, level: int = 1) -> None:
        """Put the CuMemAllocator to sleep and log how much memory was freed.

        With ``level == 1`` the allocator is asked to offload allocations
        tagged ``"weights"``; with any other level no tags are offloaded.
        With ``level == 2`` every named buffer of the model is first copied
        to CPU so that ``wake_up`` can restore it later.

        Raises:
            AssertionError: if less GPU memory is free after sleeping than
                before.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        # Level 2 sleep: keep a CPU copy of each named buffer.
        if level == 2:
            saved: dict[str, torch.Tensor] = {}
            for buf_name, buf in self.model_runner.model.named_buffers():
                saved[buf_name] = buf.cpu().clone()
            self._sleep_saved_buffers = saved

        if level == 1:
            offload_tags = ("weights",)
        else:
            offload_tags = ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %s GiB memory, %s GiB memory is still in use.",
            format_gib(freed_bytes),
            format_gib(total - free_after),
        )
133

134
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the CuMemAllocator and restore state saved by ``sleep``.

        Args:
            tags: Allocation tags forwarded to ``CuMemAllocator.wake_up``.
                ``None`` is treated here as including the KV cache.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        # Copy back the buffers stashed by a level 2 sleep, then drop the stash.
        if self._sleep_saved_buffers:
            for buf_name, buf in self.model_runner.model.named_buffers():
                if buf_name in self._sleep_saved_buffers:
                    buf.data.copy_(self._sleep_saved_buffers[buf_name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up, the internal state of the
        # cache engine must be reset, especially the FP8 scaling factor.
        kv_cache_woken = tags is None or "kv_cache" in tags
        if not kv_cache_woken:
            return
        if not self.cache_config.cache_dtype.startswith("fp8"):
            return
        if hasattr(self.model_runner, "init_fp8_kv_scales"):
            self.model_runner.init_fp8_kv_scales()

158
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
159
160
161
162
163
164
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
165
166
                    "Sleep mode can only be used for one instance per process."
                )
167
            return allocator.use_memory_pool(tag=tag)
168
        else:
169
            return nullcontext()
170

171
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
172
173
174
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

175
    def init_device(self):
        """Bind this worker to its CUDA device and build the model runner.

        Steps, in order: resolve the (possibly DP-adjusted) local rank, select
        the device, initialize the distributed environment, seed the RNG, take
        the initial memory snapshot, initialize the workspace manager, and
        construct the V1 or V2 model runner. Rank 0 additionally reports usage
        stats.

        Raises:
            RuntimeError: if the configured device type is not ``"cuda"``.
        """
        if self.device_config.device_type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        parallel_config = self.parallel_config
        if (
            parallel_config.distributed_executor_backend
            not in ("ray", "external_launcher")
            and parallel_config.data_parallel_backend != "ray"
            and parallel_config.nnodes_within_dp == 1
        ):
            # Use local DP rank if available, otherwise use global DP rank.
            dp_rank = self.parallel_config.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = self.parallel_config.data_parallel_index

            ranks_per_dp_replica = (
                self.parallel_config.pipeline_parallel_size
                * self.parallel_config.tensor_parallel_size
            )

            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_rank * ranks_per_dp_replica
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            if torch.cuda.is_available():
                visible_device_count = torch.cuda.device_count()
            else:
                visible_device_count = 0
            assert self.parallel_config.local_world_size <= visible_device_count, (
                f"local_world_size ({self.parallel_config.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({visible_device_count})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # snapshot, so that NCCL buffers are already allocated when the
        # available memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Release what can be released, then snapshot the current memory state.
        gc.collect()
        torch.cuda.empty_cache()

        init_snapshot = MemorySnapshot(device=self.device)
        self.init_snapshot = init_snapshot
        self.requested_memory = request_memory(init_snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Initialize workspace manager: two micro-batches when DBO is enabled.
        if self.vllm_config.parallel_config.enable_dbo:
            num_ubatches = 2
        else:
            num_ubatches = 1
        init_workspace_manager(self.device, num_ubatches)

        # Construct the model runner.
        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

268
269
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
270
    def load_model(self) -> None:
        """Load the model weights through the model runner.

        Loading runs with two context managers active at the same time: the
        memory-pool context for the ``"weights"`` tag returned by
        ``_maybe_get_memory_pool_context`` and ``set_current_vllm_config``.
        ``eep_scale_up`` is true when the ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH``
        environment variable equals ``"1"``.
        """
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        # The two managers must be comma-separated. `with a and b:` evaluates
        # the boolean expression first and enters only `b`, which would leave
        # the weights memory pool inactive while the weights are allocated.
        weights_pool = self._maybe_get_memory_pool_context(tag="weights")
        with weights_pool, set_current_vllm_config(self.vllm_config):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)
276

277
278
279
    def update_config(self, overrides: dict[str, Any]) -> None:
        self.model_runner.update_config(overrides)

280
    def reload_weights(self) -> None:
281
        self.model_runner.reload_weights()
282

283
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return the number of bytes that can be given to the KV cache.

        If ``cache_config.kv_cache_memory_bytes`` is set, a profile run is
        still executed (it compiles the model for max_num_batched_tokens) but
        the configured value is returned as-is and memory profiling is
        skipped. Otherwise a dummy forward pass is profiled and the result is
        the requested memory minus the profiled non-KV-cache memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Raises:
            AssertionError: if free GPU memory after profiling is not lower
                than at the initial snapshot.
        """
        kv_cache_memory_bytes = self.cache_config.kv_cache_memory_bytes
        if kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            logger.info(
                f"Initial free memory {format_gib(self.init_snapshot.free_memory)} "
                f"GiB, reserved {format_gib(kv_cache_memory_bytes)} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            return kv_cache_memory_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profiled:
            self.model_runner.profile_run()

        self.non_torch_memory = profiled.non_torch_increase
        self.peak_activation_memory = profiled.torch_peak_increase

        free_after_profile = profiled.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {format_gib(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {format_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profiled.non_kv_cache_memory
        )

        # Free memory at startup that lies outside the requested budget.
        outside_budget = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %s GiB; Requested memory: %f (util), %s GiB",
            format_gib(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            format_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %s GiB (total), %s GiB (within requested)",
            format_gib(free_after_profile),
            format_gib(free_after_profile - outside_budget),
        )
        logger.debug(profiled)
        logger.info_once(
            "Available KV cache memory: %s GiB",
            format_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )

        return int(self.available_kv_cache_memory_bytes)
363

364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return this worker's KV connector handshake metadata, if any.

        Returns:
            ``None`` when there is no KV transfer group or the connector has
            no handshake metadata to exchange; otherwise a single-entry dict
            mapping this worker's rank within the TP group to the metadata.
        """
        if not has_kv_transfer_group():
            return None

        # Connectors that don't need to exchange handshake metadata across
        # workers report None.
        handshake = get_kv_transfer_group().get_handshake_metadata()
        if handshake is None:
            return None

        return {get_tp_group().rank_in_group: handshake}

379
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        return kv_cache_spec

382
383
384
385
386
387
388
389
390
391
    def update_max_model_len(self, max_model_len: int) -> None:
        """Propagate an auto-fitted ``max_model_len`` to this worker.

        Used when max_model_len=-1 is configured and the engine has picked the
        largest context length that fits in GPU memory. The value is written
        to the model config and, if a model runner exists, pushed to it too.
        """
        self.model_config.max_model_len = max_model_len

        runner = self.model_runner
        if runner is not None:
            runner.update_max_model_len(max_model_len)

        logger.debug("Updated max_model_len to %d", max_model_len)

395
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The KV transfer connector is initialized first, then the model runner
        allocates the KV cache. With sleep mode enabled the allocation happens
        inside the CuMemAllocator pool tagged ``"kv_cache"``.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Reuse the shared helper rather than repeating the sleep-mode branch:
        # it yields the "kv_cache" memory pool when sleep mode is enabled and
        # a no-op context otherwise.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
413
414

    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up the model, capture CUDA graphs, and warm the sampler.

        Order of operations: dummy runs for the warm-up sizes, kernel warm-up,
        CUDA graph capture (unless ``enforce_eager``), an optional debug
        message suggesting a KV cache size, a sampler/pooler warm-up on the
        last pipeline-parallel rank, and finally a reset of the random seed.
        """
        compilation_config = self.vllm_config.compilation_config
        warmup_sizes = []

        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # warm up sizes that are not in cudagraph capture sizes,
            # but users still want to compile for better performance,
            # e.g. for the max-num-batched token size in chunked prefill.
            configured_sizes = compilation_config.compile_sizes
            warmup_sizes = [] if configured_sizes is None else configured_sizes.copy()

            captured_sizes: list[int] = []
            if compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                capture_cfg = compilation_config.cudagraph_capture_sizes
                if capture_cfg is not None:
                    captured_sizes = capture_cfg
                warmup_sizes = [s for s in warmup_sizes if s not in captured_sizes]

            # For each compile_range, if none of the batch sizes
            # in warmup_sizes or cudagraph_capture_sizes are in the range,
            # add the end of the range to ensure compilation/warmup.
            covered_sizes = set(captured_sizes)
            covered_sizes.update(s for s in warmup_sizes if isinstance(s, int))
            for compile_range in compilation_config.get_compile_ranges():
                if all(s not in compile_range for s in covered_sizes):
                    warmup_sizes.append(compile_range.end)

        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        if self.model_config.enforce_eager:
            cuda_graph_memory_bytes = 0
        else:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            reserved_bytes = non_kv_cache_memory
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory
                - reserved_bytes
                - redundancy_buffer_memory
            )
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory)
                - reserved_bytes
                - redundancy_buffer_memory
            )

            logger.debug(
                f"Free memory on device "
                f"({format_gib(self.init_snapshot.free_memory)}/"
                f"{format_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{format_gib(self.requested_memory)} GiB). "
                f"Actual usage is {format_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {format_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {format_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {format_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({format_gib(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({format_gib(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{format_gib(self.available_kv_cache_memory_bytes)} GiB."
            )

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if not self.model_runner.is_pooling_model:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)
            else:
                self.model_runner._dummy_pooler_run(hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

536
537
538
    def reset_mm_cache(self) -> None:
        self.model_runner.reset_mm_cache()

539
540
541
    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

542
543
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        supported = self.model_runner.get_supported_tasks()
        return supported
544

545
546
547
548
    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """Get encoder timing stats from model runner."""
        return self.model_runner.get_encoder_timing_stats()

549
550
    def annotate_profile(self, scheduler_output):
        """Return a context manager that annotates this iteration in the trace.

        The annotation makes it easy to distinguish context/generation request
        numbers in each iteration. A context request is a request that has not
        yet generated any tokens.

        Without a profiler a ``nullcontext()`` is returned. Otherwise the
        profiler is stepped once and the annotation has the form
        ``execute_context_<reqs>(<tokens>)_generation_<reqs>(<tokens>)``.
        """
        if not self.profiler:
            return nullcontext()

        self.profiler.step()

        details = compute_iteration_details(scheduler_output)
        label_parts = (
            "execute_context_",
            details.num_ctx_requests,
            "(",
            details.num_ctx_tokens,
            ")_generation_",
            details.num_generation_requests,
            "(",
            details.num_generation_tokens,
            ")",
        )
        annotation = "".join(map(str, label_parts))
        return self.profiler.annotate_context_manager(annotation)
574

575
576
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Delegate token sampling to the model runner and return its output."""
        sampled = self.model_runner.sample_tokens(grammar_output)
        return sampled

581
582
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one scheduled step, handling pipeline-parallel tensor exchange.

        On a non-first pipeline-parallel rank with tokens scheduled, the
        intermediate tensors are received from the PP group before running.
        If the model runner returns a ``ModelRunnerOutput``,
        ``AsyncModelRunnerOutput`` or ``None``, that value is returned as-is.
        Otherwise the result must be ``IntermediateTensors``; these are sent
        onward through the PP group and ``None`` is returned.
        """
        total_scheduled = scheduler_output.total_num_scheduled_tokens
        has_forward_pass = total_scheduled > 0

        intermediate_tensors = None
        all_gather_tensors = {}
        compilation_config = self.vllm_config.compilation_config
        parallel_config = self.vllm_config.parallel_config

        if (
            parallel_config.pipeline_parallel_size > 1
            and compilation_config.pass_config.enable_sp
            and has_forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            tokens_per_request = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            padding_info = self.model_runner._determine_batch_execution_and_padding(
                num_tokens=total_scheduled,
                num_reqs=len(tokens_per_request),
                num_scheduled_tokens_np=tokens_per_request,
                max_num_scheduled_tokens=tokens_per_request.max(),
                use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
            )
            _, batch_desc, _, _, _ = padding_info
            residual_needs_gather = not is_residual_scattered_for_sp(
                self.vllm_config, batch_desc.num_tokens
            )
            all_gather_tensors = {"residual": residual_needs_gather}

        if has_forward_pass and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            assert received is not None
            intermediate_tensors = IntermediateTensors(received)

        final_output_types = (ModelRunnerOutput, AsyncModelRunnerOutput, NoneType)
        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            # Final results (or None) go straight back to the caller.
            if isinstance(output, final_output_types):
                return output

        # Anything else must be intermediate tensors for the next PP rank.
        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        return None
652

653
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the draft token ids handed over by the model runner."""
        draft_token_ids = self.model_runner.take_draft_token_ids()
        return draft_token_ids

656
    def profile(self, is_start: bool = True):
657
        if self.profiler is None:
658
659
660
661
662
663
            raise RuntimeError(
                "Profiling is not enabled. Please set --profiler-config to enable "
                "profiling. Example: "
                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
            )
664
665
666
667
668
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

669
    def execute_dummy_batch(self) -> None:
670
        self.model_runner._dummy_run(1, uniform_decode=True)
671

672
673
674
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward the LoRA request to the model runner and return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

675
676
677
    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

678
    def list_loras(self) -> set[int]:
679
680
681
682
683
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

684
685
686
687
    def check_health(self) -> None:
        """No-op health probe: being able to run this call is the signal."""
        # worker will always be healthy as long as it's running.
        return None

688
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Run an EPLB rearrange (with shuffle) before the EP group shrinks.

        Args:
            old_ep_size: Current expert-parallel world size.
            new_ep_size: Expert-parallel world size after scaling down.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        # Ranks below new_ep_size map to themselves; the rest map to -1
        # (presumably marking ranks that go away -- confirm against rearrange()).
        rank_mapping: dict[int, int] = {}
        for ep_rank in range(old_ep_size):
            rank_mapping[ep_rank] = ep_rank if ep_rank < new_ep_size else -1

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Run an EPLB rearrange (with shuffle) after the EP group has grown.

        Args:
            old_ep_size: Expert-parallel world size before scaling up.
            new_ep_size: Expert-parallel world size after scaling up (not read
                by this method).
            global_expert_loads: Passed through to ``eplb_state.rearrange``.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        # Every pre-existing rank keeps its own index.
        previous_ranks = range(old_ep_size)
        rank_mapping = dict(zip(previous_ranks, previous_ranks))

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=rank_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Copy the data-parallel settings of ``reconfig_request`` into the config.

        The size, master IP and master port are always overwritten. The rank
        and local rank are overwritten only when the request value differs
        from ``ReconfigureRankType.KEEP_CURRENT_RANK``.
        """
        config = self.vllm_config.parallel_config
        keep_rank = ReconfigureRankType.KEEP_CURRENT_RANK

        config.data_parallel_size = reconfig_request.new_data_parallel_size

        requested_rank = reconfig_request.new_data_parallel_rank
        if requested_rank != keep_rank:
            config.data_parallel_rank = requested_rank

        requested_local_rank = reconfig_request.new_data_parallel_rank_local
        if requested_local_rank != keep_rank:
            config.data_parallel_rank_local = requested_local_rank

        config.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip
        config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
755

756
757
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """Resize the MoE layers for a new expert-parallel world size.

        The MoE layers of the main model (and of the drafter model, when it
        is a mixture-of-experts model) get their expert counts and parallel
        config rewritten. The EPLB redundant-expert count is then recomputed,
        ``prepare_communication_buffer_for_model`` is invoked again, and the
        main model's physical-expert metadata is updated.

        Args:
            old_ep_size: Expert-parallel world size before reconfiguration.
            new_ep_size: Expert-parallel world size after reconfiguration.

        Returns:
            ``None`` when ``new_ep_size < old_ep_size``; otherwise the value of
            ``eplb_state.rearrange(execute_shuffle=False)`` cast to a list of
            tensors.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config
        moe_class_names = ("FusedMoE", "SharedFusedMoE")

        def collect_moe_layers(model: torch.nn.Module) -> list[FusedMoE]:
            # Layers are selected by class name, not by isinstance.
            return [
                layer
                for layer in model.modules()
                if layer.__class__.__name__ in moe_class_names
            ]

        def resize_moe_layers(layers: list[FusedMoE], experts_per_rank: int):
            # Validate every layer before mutating any of them.
            uniform = all(
                layer.moe_config.num_local_experts == experts_per_rank
                for layer in layers
            )
            assert uniform, "All MoE modules must have the same number of experts"
            for layer in layers:
                moe_config = layer.moe_config
                moe_config.num_experts = experts_per_rank * new_ep_size
                layer.global_num_experts = moe_config.num_experts
                layer.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                moe_config.moe_parallel_config = layer.moe_parallel_config
            return layers

        runner = self.model_runner
        main_layers = collect_moe_layers(runner.model)
        num_local_experts = main_layers[0].moe_config.num_local_experts
        resize_moe_layers(main_layers, num_local_experts)

        # A missing ``drafter`` or a drafter without ``model`` both yield None.
        drafter_model = getattr(getattr(runner, "drafter", None), "model", None)
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_layers = collect_moe_layers(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_layers[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            resize_moe_layers(drafter_layers, num_local_experts)

        global_expert_loads: list[torch.Tensor] | None
        if new_ep_size >= old_ep_size:
            # Adopt the local expert count broadcast from EP rank 0 on the
            # CPU group.
            local_count = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                local_count,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(local_count.item())
            new_physical_experts = num_local_physical_experts * new_ep_size

            eplb_state = runner.eplb_state
            assert eplb_state is not None
            loads = cast(
                list[torch.Tensor], eplb_state.rearrange(execute_shuffle=False)
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - loads[0].shape[1]
            )
            global_expert_loads = loads
        else:
            # Shrinking: the physical expert count is read from the EPLB maps.
            num_local_physical_experts = num_local_experts
            eplb_state = runner.eplb_state
            assert eplb_state is not None
            new_physical_experts = (
                eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            global_expert_loads = None

        prepare_communication_buffer_for_model(runner.model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)

        runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads
859
860

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild the distributed environment for a new data-parallel size.

        Steps, in order: reshard experts if the EP group is shrinking, clean
        up the current distributed state, return early if this rank is being
        shut down, apply the new parallel config, re-run
        ``init_worker_distributed_environment``, reconfigure the MoE layers,
        and reshard experts if the EP group grew.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Group sizes must be read before cleanup_dist_env_and_memory() runs.
        ep_group = get_ep_group()
        old_ep_size, old_ep_rank = ep_group.world_size, ep_group.rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        shrinking = new_ep_size < old_ep_size
        growing = new_ep_size > old_ep_size

        if shrinking:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        shutting_down = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )
        if shutting_down:
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if growing:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
904

905
906
907
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the runner's model through ``ShardedStateLoader.save_model``.

        Args:
            path: Destination handed to the loader.
            pattern: Forwarded as the ``pattern`` keyword.
            max_size: Forwarded as the ``max_size`` keyword.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

920
921
922
923
924
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Forward ``tensorizer_config`` to the runner's ``save_tensorized_model``."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)
927

928
    def shutdown(self) -> None:
        """Shut down the KV transfer machinery and the profiler, if present."""
        # ensure_kv_transfer_shutdown can be None during interpreter shutdown.
        kv_shutdown = ensure_kv_transfer_shutdown
        if kv_shutdown is not None:
            kv_shutdown()

        profiler = self.profiler
        if profiler is not None:
            profiler.shutdown()
934

935
936

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Args:
        vllm_config: Source of the attention and parallel settings used here.
        rank: Passed to ``init_distributed_environment``.
        distributed_init_method: Init method; ``"env://"`` is used when this
            is ``None`` or empty.
        local_rank: Passed to ``init_distributed_environment``.
        backend: Passed to ``init_distributed_environment``.
    """
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    parallel_cfg = vllm_config.parallel_config

    init_batch_invariance(vllm_config.attention_config.backend)
    set_custom_all_reduce(not parallel_cfg.disable_custom_all_reduce)

    init_distributed_environment(
        parallel_cfg.world_size,
        rank,
        distributed_init_method or "env://",
        local_rank,
        backend,
    )

    ensure_model_parallel_initialized(
        parallel_cfg.tensor_parallel_size,
        parallel_cfg.pipeline_parallel_size,
        parallel_cfg.prefill_context_parallel_size,
        parallel_cfg.decode_context_parallel_size,
    )

    # Init the EC connector here, before the KV caches are initialized.
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)