"requirements/cpu.txt" did not exist on "9edec652e2adfee7c06483271f5df4b1be1bfddc"
gpu_worker.py 32 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from typing import TYPE_CHECKING, Any
10
11
12

import torch
import torch.distributed
13
import torch.nn as nn
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
19
20
21
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
22
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
23
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
24
from vllm.logger import init_logger
25
from vllm.lora.request import LoRARequest
26
from vllm.model_executor import set_random_seed
27
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
28
from vllm.platforms import current_platform
29
from vllm.sequence import IntermediateTensors
30
from vllm.tasks import SupportedTask
31
32
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
33
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
34
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
35
36
37
38
39
40
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
41
from vllm.v1.utils import report_usage_stats
42
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
43
from vllm.v1.worker.utils import is_residual_scattered_for_sp
44
from vllm.v1.worker.worker_base import WorkerBase
45
46
47
48

logger = init_logger(__name__)

# Imported only for static type checking; the annotations that use these
# names are written as strings, so they are never needed at runtime.
if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
    """Worker that runs a model on a single CUDA device.

    Owns the device selection, the distributed environment setup, a
    `GPUModelRunner`, optional torch profiling, sleep/wake-up via the
    `CuMemAllocator`, and elastic-EP reconfiguration.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Buffers saved before level-2 sleep; restored (and cleared) in wake_up.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
                ),
            )
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Release GPU memory held by the sleep-mode allocator.

        Args:
            level: 1 offloads memory tagged "weights" (other tagged memory
                is released without offload); 2 offloads nothing, but first
                copies every model buffer to CPU so `wake_up` can restore it.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone() for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())

        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Re-materialize sleeping memory (optionally only `tags`) and
        restore any model buffers saved by a level-2 `sleep`."""
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers after level 2 sleep
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context that routes allocations into the sleep-mode pool
        under `tag`, or a no-op context when sleep mode is disabled."""
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                # The allocator is a process-wide singleton; weights must be
                # the first thing placed in it.
                assert allocator.get_current_usage() == 0, (
                    "Sleep mode can only be used for one instance per process."
                )
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU/CPU KV-cache block counts in the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Select the CUDA device, initialize distributed state, take the
        initial memory snapshot and build the model runner.

        Raises:
            ValueError: if free memory at startup is below the requested
                `gpu_memory_utilization` share of total memory.
            RuntimeError: if the configured device type is not "cuda".
        """
        if self.device_config.device.type == "cuda":
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            if (
                self.parallel_config.data_parallel_size > 1
                and self.parallel_config.data_parallel_size_local > 0
                and self.parallel_config.data_parallel_backend != "ray"
            ):
                # Use local DP rank if available, otherwise use global DP rank.
                dp_local_rank = self.parallel_config.data_parallel_rank_local
                if dp_local_rank is None:
                    dp_local_rank = self.parallel_config.data_parallel_rank

                tp_pp_world_size = (
                    self.parallel_config.pipeline_parallel_size
                    * self.parallel_config.tensor_parallel_size
                )

                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
                self.local_rank += dp_local_rank * tp_pp_world_size
                # CUDA device indices are 0-based: the adjusted rank must be
                # strictly smaller than the number of visible devices.
                num_devices = torch.cuda.device_count()
                assert self.local_rank < num_devices, (
                    f"DP adjusted local rank {self.local_rank} is out of bounds "
                    f"for {num_devices} visible CUDA device(s)."
                )

            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.check_if_supports_dtype(self.model_config.dtype)

            # Initialize the distributed environment BEFORE taking
            # memory snapshot
            # This ensures NCCL buffers are allocated before we measure
            # available memory
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
                current_platform.dist_backend,
            )

            # Set random seed.
            set_random_seed(self.model_config.seed)

            # Now take memory snapshot after NCCL is initialized
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (
                self.init_snapshot.total_memory
                * self.cache_config.gpu_memory_utilization
            )
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load model weights, inside the "weights" sleep-mode pool if enabled."""
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config overrides to the model runner."""
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Ask the model runner to reload the model weights."""
        self.model_runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        If `cache_config.kv_cache_memory_bytes` is set, a profile run is still
        executed but memory profiling is skipped and that value is returned
        unchanged. Otherwise the result is `requested_memory` minus the
        non-KV-cache memory measured during the profile run.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        GiB = lambda b: b / GiB_bytes
        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return kv_cache_memory_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        # Kept on self: compile_or_warm_up_model uses these to suggest a
        # kv_cache_memory_bytes value.
        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            GiB(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            GiB(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            GiB(free_gpu_memory),
            GiB(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info(
            "Available KV cache memory: %.2f GiB",
            GiB(self.available_kv_cache_memory_bytes),
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's per-layer KV cache specs."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        # Same pool selection used for weights, tagged "kv_cache" so sleep
        # mode can release it separately.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up the model before serving.

        Runs dummy passes for `compile_sizes` not covered by CUDA graph
        capture, kernel warmup, CUDA graph capture (unless `enforce_eager`),
        logs a suggested `--kv-cache-memory` value, warms up the sampler or
        pooler on the last PP rank, and finally resets the random seed.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x
                for x in warmup_sizes
                if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            GiB = lambda b: round(b / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({GiB(self.init_snapshot.free_memory)}/"
                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). "
                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{GiB(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        """Clear the model runner's multimodal cache."""
        self.model_runner.reset_mm_cache()

    def get_model(self) -> nn.Module:
        """Return the underlying model from the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the loaded model supports."""
        return self.model_runner.get_supported_tasks()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one scheduled step, handling pipeline-parallel hand-off.

        Non-first PP ranks first receive intermediate tensors from the
        previous stage. If the runner returns a final output it is returned
        directly; otherwise the intermediate tensors are sent to the next
        stage and only the KV-connector output (if any) is returned.
        """
        intermediate_tensors = None
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        # Whether the residual must be all-gathered across TP when crossing
        # a PP boundary depends on sequence-parallel scattering.
        all_gather_tensors = {
            "residual": not is_residual_scattered_for_sp(
                self.vllm_config, num_input_tokens
            )
        }
        if forward_pass and not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )

        output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != ("external_launcher")
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Shallow copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated.
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return (and hand over) the runner's pending draft token ids."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        Raises:
            RuntimeError: if profiling was not enabled via
                VLLM_TORCH_PROFILER_DIR.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            # only print profiler results on local rank 0
            if self.local_rank == 0:
                print(
                    self.profiler.key_averages().table(sort_by="self_cuda_time_total")
                )

    def execute_dummy_batch(self) -> None:
        """Run a single-token uniform-decode dummy forward pass."""
        self.model_runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward to the model runner's `add_lora`."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward to the model runner's `remove_lora`."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Forward to the model runner's `list_loras`."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward to the model runner's `pin_lora`."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Move experts off EP ranks >= new_ep_size before they are removed."""
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        # Ranks that survive keep their index; removed ranks map to -1.
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_load: torch.Tensor | None,
    ) -> None:
        """Redistribute experts onto the enlarged EP group."""
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Update parallel config with provided reconfig_request
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """
        Reconfigure MoE modules with provided reconfig_request

        Return the global expert load if new_ep_size > old_ep_size,
        otherwise None
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig

        parallel_config = self.vllm_config.parallel_config
        moe_modules = [
            module
            for module in self.model_runner.model.modules()
            if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
        ]
        # Fail with a clear message instead of an IndexError below.
        assert moe_modules, "Elastic EP reconfiguration requires MoE modules"
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config
        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]
            )
            global_expert_load = None
        else:
            # EP rank 0's local expert count is broadcast so every rank
            # (including newly launched ones) agrees on it.
            num_local_physical_experts = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = num_local_physical_experts.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1]
            )
        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_load

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed environment for a new DP size.

        Experts are resharded before scale-down and after scale-up. A rank
        whose new DP rank is SHUTDOWN_CURRENT_RANK returns right after
        cleanup.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            # Use the platform's backend, exactly as init_device does, rather
            # than silently falling back to the "nccl" default.
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
                current_platform.dist_backend,
            )

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model with `ShardedStateLoader.save_model`."""
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Forward to the model runner's `save_tensorized_model`."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )

    def shutdown(self) -> None:
        """Shut down KV transfer if a model runner was ever created."""
        if runner := getattr(self, "model_runner", None):
            runner.ensure_kv_transfer_shutdown()


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    In order: batch-invariance setup, the custom all-reduce toggle, the
    global process group, the tensor/pipeline/decode-context parallel
    groups, and KV transfer.

    Args:
        vllm_config: Source of the parallel configuration.
        rank: Global rank of this worker.
        distributed_init_method: Rendezvous URL passed to
            `init_distributed_environment`.
        local_rank: Rank within the node (-1 if unspecified).
        backend: torch.distributed backend name; callers should pass the
            platform's backend rather than rely on the "nccl" default.
    """
    parallel_config = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(
        parallel_config.world_size, rank, distributed_init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.decode_context_parallel_size,
    )

    ensure_kv_transfer_initialized(vllm_config)