gpu_worker.py 36 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any
10
11
12

import torch
import torch.distributed
13
import torch.nn as nn
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
19
20
21
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
22
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
23
24
25
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
28
from vllm.distributed.parallel_state import (
29
    get_pcp_group,
30
31
32
    get_pp_group,
    get_tp_group,
)
33
from vllm.logger import init_logger
34
from vllm.lora.request import LoRARequest
35
from vllm.model_executor import set_random_seed
36
from vllm.model_executor.models.interfaces import is_mixture_of_experts
37
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
38
from vllm.platforms import current_platform
39
from vllm.profiler.gpu_profiler import CudaProfilerWrapper
40
from vllm.sequence import IntermediateTensors
41
from vllm.tasks import SupportedTask
42
43
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
44
from vllm.v1.core.sched.output import GrammarOutput
45
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
46
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
47
48
49
50
51
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
52
from vllm.v1.utils import report_usage_stats
53
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
54
from vllm.v1.worker.utils import is_residual_scattered_for_sp
55
from vllm.v1.worker.worker_base import WorkerBase
56
57
58
59

logger = init_logger(__name__)

if TYPE_CHECKING:
60
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
61
    from vllm.v1.core.sched.output import SchedulerOutput
62
63


64
class Worker(WorkerBase):
65
66
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Set up the worker and, when requested through env vars, a profiler.

        All arguments are forwarded unchanged to the base-class initializer.
        Afterwards ``self.profiler`` is a ``torch.profiler.profile`` when
        ``VLLM_TORCH_PROFILER_DIR`` is set, a ``CudaProfilerWrapper`` when
        ``VLLM_TORCH_CUDA_PROFILE`` is set, and ``None`` otherwise.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Buffers captured by a level 2 sleep() so that wake_up() can
        # copy them back.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        if trace_dir:
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                trace_dir,
            )
            # Collected once so the logged values are exactly the ones
            # handed to the profiler.
            profiler_options = {
                "record_shapes": envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                "profile_memory": envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                "with_stack": envs.VLLM_TORCH_PROFILER_WITH_STACK,
                "with_flops": envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            }
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                *profiler_options.values(),
            )
            trace_handler = torch.profiler.tensorboard_trace_handler(
                trace_dir, worker_name=worker_name, use_gzip=True
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                on_trace_ready=trace_handler,
                **profiler_options,
            )
        elif envs.VLLM_TORCH_CUDA_PROFILE:
            self.profiler = CudaProfilerWrapper()
        else:
            self.profiler = None
124

125
    def sleep(self, level: int = 1) -> None:
        """Put the allocator to sleep and log how much device memory was freed.

        Args:
            level: With ``1`` the allocator is asked to offload allocations
                tagged ``"weights"``; any other value passes no offload tags.
                With ``2`` every model buffer is first cloned to CPU so that
                ``wake_up`` can restore it.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        # Save the buffers before level 2 sleep
        if level == 2:
            saved: dict[str, torch.Tensor] = {}
            for buf_name, buf in self.model_runner.model.named_buffers():
                saved[buf_name] = buf.cpu().clone()
            self._sleep_saved_buffers = saved

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total_bytes = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            (total_bytes - free_after) / GiB_bytes,
        )
148

149
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the allocator and restore buffers saved by a level 2 sleep.

        Args:
            tags: Passed through unchanged to the allocator's ``wake_up``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        # Restore the buffers after level 2 sleep
        saved = self._sleep_saved_buffers
        if not saved:
            return
        for buf_name, buf in self.model_runner.model.named_buffers():
            if buf_name in saved:
                buf.data.copy_(saved[buf_name].data)
        # Drop the CPU copies once they have been written back.
        self._sleep_saved_buffers = {}

163
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return the context manager under which ``tag`` memory is allocated.

        When sleep mode is disabled this is a no-op ``nullcontext``; otherwise
        it is the allocator's memory pool for ``tag``. For the ``"weights"``
        tag the allocator must not have any usage recorded yet.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

177
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU and CPU block counts on the cache config."""
        cache_config = self.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

181
    def init_device(self):
        """Bind the worker to its CUDA device and construct the model runner.

        In order: adjusts ``self.local_rank`` for data parallelism when
        applicable, selects the device, initializes the distributed
        environment, seeds the RNG, takes the initial memory snapshot and
        checks it against the requested utilization, then builds the
        ``GPUModelRunner``. Rank 0 additionally reports usage stats.

        Raises:
            RuntimeError: If the configured device type is not ``"cuda"``.
            ValueError: If free memory at startup is below
                ``total_memory * gpu_memory_utilization``.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        parallel_config = self.parallel_config
        needs_dp_rank_offset = (
            parallel_config.data_parallel_size > 1
            and parallel_config.data_parallel_size_local > 0
            and parallel_config.distributed_executor_backend
            not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
            and self.vllm_config.parallel_config.nnodes_within_dp == 1
        )
        if needs_dp_rank_offset:
            # Use local DP rank if available, otherwise use global DP rank.
            dp_rank = parallel_config.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = parallel_config.data_parallel_rank

            ranks_per_dp_replica = (
                parallel_config.pipeline_parallel_size
                * parallel_config.tensor_parallel_size
            )

            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_rank * ranks_per_dp_replica
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

        if torch.cuda.is_available():
            visible_device_count = torch.cuda.device_count()
        else:
            visible_device_count = 0
        assert self.parallel_config.local_world_size <= visible_device_count, (
            f"local_world_size ({self.parallel_config.local_world_size}) must be "
            f"less than or equal to the number of visible devices "
            f"({visible_device_count})."
        )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking
        # memory snapshot
        # This ensures NCCL buffers are allocated before we measure
        # available memory
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Now take memory snapshot after NCCL is initialized
        gc.collect()
        torch.cuda.empty_cache()

        # take current memory snapshot
        snapshot = MemorySnapshot()
        self.init_snapshot = snapshot
        self.requested_memory = (
            snapshot.total_memory * self.cache_config.gpu_memory_utilization
        )
        if snapshot.free_memory < self.requested_memory:

            def _gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            raise ValueError(
                f"Free memory on device "
                f"({_gib(snapshot.free_memory)}/"
                f"{_gib(snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{_gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

269
270
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
271
    def load_model(self) -> None:
        """Load the model inside the "weights" memory-pool context.

        ``eep_scale_up`` is true only when the environment variable
        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` equals ``"1"``.
        """
        launch_flag = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH")
        is_scale_up_launch = launch_flag == "1"
        weights_context = self._maybe_get_memory_pool_context(tag="weights")
        with weights_context:
            self.model_runner.load_model(eep_scale_up=is_scale_up_launch)
275

276
277
278
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward the config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

279
    def reload_weights(self) -> None:
        """Ask the model runner to reload its weights."""
        runner = self.model_runner
        runner.reload_weights()
281

282
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return the number of bytes that may be used for the KV cache.

        If ``kv_cache_memory_bytes`` is configured, a profile run is still
        executed (it compiles the model) but that configured value is
        returned as-is. Otherwise a profile run is measured with
        ``memory_profiling`` and the result is the requested memory minus the
        profiled non-KV-cache memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """

        def _gib(num_bytes):
            return num_bytes / GiB_bytes

        init_snapshot = self.init_snapshot

        fixed_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if fixed_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            logger.info(
                f"Initial free memory {_gib(init_snapshot.free_memory):.2f} "
                f"GiB, reserved {_gib(fixed_kv_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            return fixed_kv_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        weights_bytes = int(self.model_runner.model_memory_usage)
        with memory_profiling(
            init_snapshot, weights_memory=weights_bytes
        ) as profile_result:
            self.model_runner.profile_run()

        # Kept on the instance; compile_or_warm_up_model() reads them later.
        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_after_profile = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert init_snapshot.free_memory > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {_gib(init_snapshot.free_memory)} GiB, "
            f"current free memory {_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Free memory at startup that lies outside the requested budget.
        outside_budget = init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            _gib(init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            _gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            _gib(free_after_profile),
            _gib(free_after_profile - outside_budget),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            _gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
368

369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return ``{tp_rank: handshake_metadata}`` for this worker, if any.

        ``None`` is returned when there is no KV transfer group, or when the
        connector reports no handshake metadata.
        """
        if not has_kv_transfer_group():
            return None

        # Connectors that don't need to exchange handshake metadata across
        # workers report None here.
        handshake_metadata = get_kv_transfer_group().get_handshake_metadata()
        if handshake_metadata is None:
            return None

        return {get_tp_group().rank_in_group: handshake_metadata}

384
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        return kv_cache_spec

387
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The KV cache is allocated inside the context returned by
        ``_maybe_get_memory_pool_context(tag="kv_cache")``: the allocator's
        "kv_cache" memory pool when sleep mode is enabled, a no-op context
        otherwise.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Reuse the shared helper rather than duplicating the sleep-mode
        # branch; for the "kv_cache" tag it performs no extra checks.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
406
407

    def compile_or_warm_up_model(self) -> None:
        """Warm up the model, capture CUDA graphs and warm up the sampler.

        Steps, in order: dummy runs for the configured compile sizes, kernel
        warmup, CUDA graph capture (skipped with ``enforce_eager``), a debug
        message suggesting a KV cache size (only when memory profiling ran),
        a sampler/pooler dummy run on the last pipeline rank, and finally a
        reset of the random seed.
        """
        compilation_config = self.vllm_config.compilation_config
        runner = self.model_runner

        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        sizes_to_warm = compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            sizes_to_warm = [
                size
                for size in sizes_to_warm
                if size not in self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(sizes_to_warm, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        runner.maybe_remove_all_loras(runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        if self.model_config.enforce_eager:
            cuda_graph_memory_bytes = 0
        else:
            cuda_graph_memory_bytes = runner.capture_model()

        profiled = hasattr(self, "peak_activation_memory")
        if self.cache_config.kv_cache_memory_bytes is None and profiled:
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            def _gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            safety_margin_bytes = 150 * (1 << 20)

            non_kv_bytes = (
                runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory - non_kv_bytes - safety_margin_bytes
            )
            kv_bytes_to_requested_limit = (
                int(self.requested_memory) - non_kv_bytes - safety_margin_bytes
            )

            logger.debug(
                f"Free memory on device "
                f"({_gib(self.init_snapshot.free_memory)}/"
                f"{_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{_gib(self.requested_memory)} GiB). "
                f"Actual usage is {_gib(runner.model_memory_usage)} "
                f"GiB for weight, {_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_bytes_to_requested_limit}` "
                f"({_gib(kv_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_bytes_to_gpu_limit}` "
                f"({_gib(kv_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{_gib(self.available_kv_cache_memory_bytes)} GiB."
            )

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            scheduler_config = self.scheduler_config
            max_num_reqs = min(
                scheduler_config.max_num_seqs,
                scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

514
515
516
    def reset_mm_cache(self) -> None:
        """Forward the multimodal-cache reset to the model runner."""
        runner = self.model_runner
        runner.reset_mm_cache()

517
518
519
    def get_model(self) -> nn.Module:
        """Return the module obtained from the model runner's ``get_model``."""
        model = self.model_runner.get_model()
        return model

520
521
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        supported_tasks = self.model_runner.get_supported_tasks()
        return supported_tasks
522

523
524
525
526
527
528
529
530
531
532
533
534
535
    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels one iteration in the trace.

        The label encodes the number of new and cached requests so they can
        be told apart per iteration. When no profiler is active a no-op
        ``nullcontext`` is returned instead.
        """
        if self.profiler:
            new_count = len(scheduler_output.scheduled_new_reqs)
            cached_count = len(scheduler_output.scheduled_cached_reqs.req_ids)
            label = f"execute_new_{new_count}_cached_{cached_count}"
            return torch.profiler.record_function(label)

        return nullcontext()

536
537
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Forward ``grammar_output`` to the runner's ``sample_tokens``."""
        sampled = self.model_runner.sample_tokens(grammar_output)
        return sampled

542
543
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one model step, exchanging tensors with neighbouring PP ranks.

        When tokens are scheduled and this is not the first pipeline rank,
        intermediate tensors are first received from the pipeline group. If
        the runner returns a ``ModelRunnerOutput`` or ``None`` it is returned
        directly; otherwise the result must be ``IntermediateTensors``, which
        are sent on through the pipeline group and ``None`` is returned.
        """
        pp_inputs = None
        scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        has_work = scheduler_output.total_num_scheduled_tokens > 0
        padded_tokens = self.model_runner._get_num_input_tokens(scheduled_tokens)
        residual_is_scattered = is_residual_scattered_for_sp(
            self.vllm_config, padded_tokens
        )
        # Only gather "residual" when it was not scattered for sequence
        # parallelism.
        all_gather_tensors = {"residual": not residual_is_scattered}

        if has_work and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            pp_inputs = IntermediateTensors(received)

        with self.annotate_profile(scheduler_output):
            result = self.model_runner.execute_model(scheduler_output, pp_inputs)
            if isinstance(result, (ModelRunnerOutput, NoneType)):
                return result

        assert isinstance(result, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            result.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        return None
584

585
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return whatever the runner's ``take_draft_token_ids`` yields."""
        draft_token_ids = self.model_runner.take_draft_token_ids()
        return draft_token_ids

588
    def profile(self, is_start: bool = True):
        """Start or stop the configured profiler.

        When a ``torch.profiler.profile`` is stopped, its key-averages table
        (sorted by ``self_cuda_time_total``) is written to
        ``<VLLM_TORCH_PROFILER_DIR>/profiler_out_<local_rank>.txt`` and, on
        local rank 0, also printed to stdout.

        Args:
            is_start: ``True`` starts the profiler, ``False`` stops it.

        Raises:
            RuntimeError: If no profiler was enabled for this worker.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")

        if is_start:
            self.profiler.start()
            return

        self.profiler.stop()
        # Other profiler wrappers have no table to export.
        if not isinstance(self.profiler, torch.profiler.profile):
            return

        rank = self.local_rank
        profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
        profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
        sort_key = "self_cuda_time_total"
        table = self.profiler.key_averages().table(sort_by=sort_key)

        # Explicit encoding so the summary file does not depend on the
        # process locale.
        with open(profiler_out_file, "w", encoding="utf-8") as f:
            print(table, file=f)

        # only print profiler results on rank 0
        if rank == 0:
            print(table)
608

609
    def execute_dummy_batch(self) -> None:
        """Run a one-token dummy pass with ``uniform_decode`` enabled."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)
611

612
613
614
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model runner and return its result."""
        was_added = self.model_runner.add_lora(lora_request)
        return was_added

615
616
617
    def remove_lora(self, lora_id: int) -> bool:
        """Forward the removal of ``lora_id`` to the model runner."""
        was_removed = self.model_runner.remove_lora(lora_id)
        return was_removed

618
    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        lora_ids = self.model_runner.list_loras()
        return lora_ids

    def pin_lora(self, lora_id: int) -> bool:
        """Forward the pin request for ``lora_id`` to the model runner."""
        was_pinned = self.model_runner.pin_lora(lora_id)
        return was_pinned

624
625
626
627
    def check_health(self) -> None:
        """Report health; performs no checks.

        The worker will always be healthy as long as it's running.
        """
        return None

628
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts before the expert-parallel group shrinks.

        Builds a rank mapping in which every old EP rank below
        ``new_ep_size`` maps to itself and every other old rank maps to
        ``-1``, asks the EPLB state to rearrange with that mapping, then
        synchronizes CUDA.

        Args:
            old_ep_size: Expert-parallel size before scaling down.
            new_ep_size: Expert-parallel size after scaling down.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        # Keyword is the plural `global_expert_loads`, matching the call in
        # `_eplb_after_scale_up`.
        # NOTE(review): confirm against EplbState.rearrange's signature.
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts after the expert-parallel group has grown.

        Every old EP rank maps to itself; ``global_expert_loads`` is passed
        through to the EPLB state's ``rearrange``. ``new_ep_size`` is not
        used in this method.
        """
        from vllm.distributed.parallel_state import get_ep_group

        is_ep_rank_zero = get_ep_group().rank == 0
        if is_ep_rank_zero:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        identity_mapping = {}
        for ep_rank in range(old_ep_size):
            identity_mapping[ep_rank] = ep_rank

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=identity_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Update parallel config with provided reconfig_request

        The data-parallel size, master IP and master port are always
        overwritten. The global and local data-parallel ranks are only
        overwritten when the request does not carry
        ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.
        """
        parallel_config = self.vllm_config.parallel_config
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK

        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size

        new_rank = reconfig_request.new_data_parallel_rank
        if new_rank != keep_current:
            parallel_config.data_parallel_rank = new_rank

        new_rank_local = reconfig_request.new_data_parallel_rank_local
        if new_rank_local != keep_current:
            parallel_config.data_parallel_rank_local = new_rank_local

        new_master_ip = reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_ip = new_master_ip
        new_master_port = reconfig_request.new_data_parallel_master_port
        parallel_config.data_parallel_master_port = new_master_port
695

696
697
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """
        Reconfigure the MoE modules for a new expert-parallel (EP) size.

        Every ``FusedMoE`` / ``SharedFusedMoE`` module of the model (and of the
        drafter model, when one exists and is a mixture of experts) gets its
        global expert count and its ``FusedMoEParallelConfig`` rebuilt for
        ``new_ep_size``. The number of redundant experts in the EPLB config is
        then recomputed, the communication buffers are prepared again, and the
        model's physical-expert metadata is updated.

        Args:
            old_ep_size: EP size before the reconfiguration.
            new_ep_size: EP size after the reconfiguration.

        Returns:
            ``None`` when ``new_ep_size < old_ep_size``. Otherwise (including
            when the size is unchanged) the global expert loads returned by
            ``eplb_state.rearrange(execute_shuffle=False)``.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Modules are selected by class name, not by isinstance.
            return [
                module
                for module in model.modules()
                if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
            ]

        def update_moe_modules(
            moe_modules: list[FusedMoE], num_local_experts: int
        ) -> list[FusedMoE]:
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of experts"
            for module in moe_modules:
                # Global expert count scales with the new EP size.
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config
            return moe_modules

        model_moe_modules = get_moe_modules(self.model_runner.model)
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts

        update_moe_modules(model_moe_modules, num_local_experts)
        drafter_model = None
        if hasattr(self.model_runner, "drafter") and hasattr(
            self.model_runner.drafter, "model"
        ):
            drafter_model = self.model_runner.drafter.model
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            # Scale down: sizes are read from the existing EPLB state.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]
            )
            global_expert_loads = None
        else:
            # Every rank adopts the local expert count broadcast from group
            # rank 0 over the EP CPU group.
            local_experts_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                local_experts_tensor, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = local_experts_tensor.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            # execute_shuffle=False: only the loads are requested here.
            global_expert_loads = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )

        prepare_communication_buffer_for_model(self.model_runner.model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Tear down and rebuild the distributed environment for a new DP size.

        The new EP size is computed as
        ``new_data_parallel_size * TP world size * PP world size``. When it is
        smaller than the current EP size, ``_eplb_before_scale_down`` runs
        before the distributed environment is cleaned up. A rank whose request
        carries ``ReconfigureRankType.SHUTDOWN_CURRENT_RANK`` returns right
        after the cleanup. Every other rank updates its parallel config,
        re-initializes the distributed environment and reconfigures its MoE
        modules; when the EP size grew, ``_eplb_after_scale_up`` runs last.

        Args:
            reconfig_request: Request describing the new data-parallel layout.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Captured before cleanup_dist_env_and_memory() is called below.
        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            # A rank that is shut down must lie outside the new EP range.
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """
        Save the model runner's model through ``ShardedStateLoader.save_model``.

        Args:
            path: Forwarded to ``ShardedStateLoader.save_model`` as the
                destination argument.
            pattern: Forwarded unchanged as ``pattern``.
            max_size: Forwarded unchanged as ``max_size``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """
        Delegate to ``self.model_runner.save_tensorized_model``.

        Args:
            tensorizer_config: Forwarded unchanged to the model runner. The
                runner's return value, if any, is discarded.
        """
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )

    def shutdown(self) -> None:
        """
        Ask the model runner to shut down its KV transfer, if a runner exists.

        Nothing happens when the ``model_runner`` attribute is missing or
        falsy, so this is safe to call on a worker that has no runner set.
        """
        if runner := getattr(self, "model_runner", None):
            runner.ensure_kv_transfer_shutdown()


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    In order: initializes batch invariance, enables or disables custom
    all-reduce according to ``parallel_config.disable_custom_all_reduce``,
    initializes the distributed environment, ensures the model-parallel
    groups exist, and ensures EC transfer is initialized.

    Args:
        vllm_config: Config whose ``parallel_config`` supplies the world size
            and the tensor / pipeline / prefill-context / decode-context
            parallel sizes.
        rank: Forwarded to ``init_distributed_environment``.
        distributed_init_method: Forwarded to
            ``init_distributed_environment``.
        local_rank: Forwarded to ``init_distributed_environment``.
        backend: Forwarded to ``init_distributed_environment``.
    """
    parallel_config = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(
        parallel_config.world_size, rank, distributed_init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.prefill_context_parallel_size,
        parallel_config.decode_context_parallel_size,
    )

    # Init ec connector here before KV caches init
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)