gpu_worker.py 31 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from typing import TYPE_CHECKING, Any
10
11
12

import torch
import torch.distributed
13
import torch.nn as nn
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
19
20
21
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
22
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
23
24
25
26
from vllm.distributed.parallel_state import (
    get_pp_group,
    get_tp_group,
)
27
from vllm.logger import init_logger
28
from vllm.lora.request import LoRARequest
29
from vllm.model_executor import set_random_seed
30
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
31
from vllm.platforms import current_platform
32
from vllm.sequence import IntermediateTensors
33
from vllm.tasks import SupportedTask
34
35
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
36
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
37
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
38
39
40
41
42
43
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
44
from vllm.v1.utils import report_usage_stats
45
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
46
from vllm.v1.worker.utils import is_residual_scattered_for_sp
47
from vllm.v1.worker.worker_base import WorkerBase
48
49
50
51

logger = init_logger(__name__)

if TYPE_CHECKING:
52
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
53
    from vllm.v1.core.sched.output import SchedulerOutput
54
55


56
class Worker(WorkerBase):
57
58
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create the worker and, when requested, its torch profiler.

        Every argument is forwarded unchanged to the base-class
        constructor.

        Args:
            vllm_config: Top-level configuration object.
            local_rank: Rank forwarded to the base class as ``local_rank``.
            rank: Rank forwarded to the base class as ``rank``; also used
                in the profiler's worker name.
            distributed_init_method: Forwarded to the base class.
            is_driver_worker: Forwarded to the base class.

        After construction ``self.profiler`` is a ``torch.profiler.profile``
        instance when ``VLLM_TORCH_PROFILER_DIR`` is set, otherwise ``None``.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # Imported lazily: only needed when remote code is trusted.
            from vllm.utils import init_cached_hf_modules

            init_cached_hf_modules()

        # CPU copies of the model buffers, filled by sleep(level=2) and
        # consumed by wake_up().
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        if not trace_dir:
            self.profiler = None
            return

        worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
        logger.info(
            "Profiling enabled. Traces will be saved to: %s",
            trace_dir,
        )
        logger.debug(
            "Profiler config: record_shapes=%s,"
            "profile_memory=%s,with_stack=%s,with_flops=%s",
            envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
            envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
            envs.VLLM_TORCH_PROFILER_WITH_STACK,
            envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
        )

        profiled_activities = [
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
        self.profiler = torch.profiler.profile(
            activities=profiled_activities,
            record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
            profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
            with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
            with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                trace_dir, worker_name=worker_name, use_gzip=True
            ),
        )
114

115
    def sleep(self, level: int = 1) -> None:
        """Put the allocator to sleep and log how much GPU memory was freed.

        Args:
            level: With ``1`` the ``"weights"`` tag is handed to the
                allocator as an offload tag; with any other value no tag
                is offloaded. With ``2`` the model's named buffers are
                first copied to CPU so that ``wake_up`` can restore them.

        Raises:
            AssertionError: If less memory is free after sleeping than
                before.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        # Save the buffers before level 2 sleep
        if level == 2:
            named_buffers = self.model_runner.model.named_buffers()
            self._sleep_saved_buffers = {
                name: buf.cpu().clone() for name, buf in named_buffers
            }

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        still_used_bytes = total - free_after
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            still_used_bytes / GiB_bytes,
        )
138

139
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the allocator and restore buffers saved by a level 2 sleep.

        Args:
            tags: Passed straight to the allocator's ``wake_up``. ``None``
                is treated here as "everything"; otherwise only the listed
                tags are considered awake.

        The CPU copies made by ``sleep(level=2)`` are copied back into the
        model's buffers only when this call wakes the ``"weights"`` tag
        (or ``tags`` is ``None``). Until then the copies are kept, so a
        later call that does wake the weights can still restore them.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # The model is loaded inside the "weights" memory-pool context (see
        # load_model), so its buffers are presumably allocated under that
        # tag. NOTE(review): this assumes the allocator leaves allocations
        # of non-listed tags asleep; copying into them before the weights
        # are awake would then touch memory that is not usable yet.
        weights_awake = tags is None or "weights" in tags
        if not weights_awake or not self._sleep_saved_buffers:
            return

        # Restore the buffers after level 2 sleep
        saved_buffers = self._sleep_saved_buffers
        for name, buffer in self.model_runner.model.named_buffers():
            if name in saved_buffers:
                buffer.data.copy_(saved_buffers[name].data)
        self._sleep_saved_buffers = {}

153
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return the context manager under which ``tag`` allocations run.

        When sleep mode is disabled this is a ``nullcontext``. Otherwise it
        is the allocator's ``use_memory_pool(tag=tag)`` context.

        Raises:
            AssertionError: If ``tag`` is ``"weights"`` and the allocator
                already reports non-zero usage.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

167
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU and CPU block counts on the cache config."""
        cache_config = self.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

171
    def init_device(self):
        """Bind the CUDA device, set up distributed state and the runner.

        Steps, in order: select ``cuda:{local_rank}``, check the model
        dtype against the platform, initialize the distributed environment,
        seed the RNG, take a memory snapshot, verify that enough memory is
        free for the requested utilization, build the ``GPUModelRunner``
        and (on rank 0) report usage stats.

        Raises:
            RuntimeError: If the configured device type is not ``"cuda"``.
            ValueError: If free device memory at startup is below
                ``total_memory * gpu_memory_utilization``.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # The distributed environment is initialized BEFORE the memory
        # snapshot so that NCCL buffers are already allocated when the
        # available memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Drop garbage and cached blocks, then snapshot current memory.
        gc.collect()
        torch.cuda.empty_cache()
        self.init_snapshot = MemorySnapshot()

        snapshot = self.init_snapshot
        utilization = self.cache_config.gpu_memory_utilization
        self.requested_memory = snapshot.total_memory * utilization
        if snapshot.free_memory < self.requested_memory:

            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            raise ValueError(
                f"Free memory on device "
                f"({to_gib(snapshot.free_memory)}/"
                f"{to_gib(snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({utilization}, "
                f"{to_gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

228
229
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
230
    def load_model(self) -> None:
        """Load the model via the model runner inside the "weights" context.

        ``eep_scale_up`` is passed as ``True`` only when the environment
        variable ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` equals ``"1"``.
        """
        scale_up_flag = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH")
        weights_context = self._maybe_get_memory_pool_context(tag="weights")
        with weights_context:
            self.model_runner.load_model(eep_scale_up=(scale_up_flag == "1"))
234

235
236
237
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward configuration overrides to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

238
    def reload_weights(self) -> None:
        """Ask the model runner to reload its weights."""
        runner = self.model_runner
        runner.reload_weights()
240

241
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return how many bytes may be used for the KV cache.

        If ``cache_config.kv_cache_memory_bytes`` is set, a profile run is
        still executed but that configured value is returned unchanged.
        Otherwise a profile run is measured with ``memory_profiling`` and
        the result is ``requested_memory`` minus the profiled non-KV-cache
        memory. In that case ``non_torch_memory``,
        ``peak_activation_memory`` and ``available_kv_cache_memory_bytes``
        are stored on the worker.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Raises:
            AssertionError: If free memory after profiling is not below
                the free memory recorded in the initial snapshot.
        """

        def to_gib(num_bytes):
            return num_bytes / GiB_bytes

        baseline = self.init_snapshot
        runner = self.model_runner

        configured_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if configured_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            runner.profile_run()

            msg = (
                f"Initial free memory {to_gib(baseline.free_memory):.2f} "
                f"GiB, reserved {to_gib(configured_kv_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return configured_kv_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            baseline,
            weights_memory=int(runner.model_memory_usage),
        ) as profile_result:
            runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_after_profile = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert baseline.free_memory > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {to_gib(baseline.free_memory)} GiB, "
            f"current free memory {to_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Memory that was free at startup but lies outside the requested
        # budget.
        outside_budget = baseline.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            to_gib(baseline.free_memory),
            self.cache_config.gpu_memory_utilization,
            to_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            to_gib(free_after_profile),
            to_gib(free_after_profile - outside_budget),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            to_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
327

328
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

331
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        With sleep mode enabled the allocation runs inside the allocator's
        ``"kv_cache"`` memory pool; otherwise inside a no-op context.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            kv_cache_context = nullcontext()
        else:
            from vllm.device_allocator.cumem import CuMemAllocator

            kv_cache_context = CuMemAllocator.get_instance().use_memory_pool(
                tag="kv_cache"
            )

        with kv_cache_context:
            self.model_runner.initialize_kv_cache(kv_cache_config)
343
344

    def compile_or_warm_up_model(self) -> None:
        """Warm up the model, capture CUDA graphs and warm up the sampler.

        Steps, in order:
          1. Dummy-run every configured compile size (skipping sizes that
             are CUDA graph capture sizes unless eager mode is enforced).
          2. Run ``kernel_warmup``.
          3. Capture the model unless eager mode is enforced.
          4. When the KV cache size came from profiling, log (at debug
             level) a suggested ``--kv-cache-memory`` value.
          5. On the last pipeline-parallel rank, run a dummy pooler or
             sampler pass.
          6. Re-seed the RNG.
        """
        runner = self.model_runner
        compilation_config = self.vllm_config.compilation_config

        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        sizes_to_warm = compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            capture_sizes = compilation_config.cudagraph_capture_sizes
            sizes_to_warm = [s for s in sizes_to_warm if s not in capture_sizes]

        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(sizes_to_warm, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        runner.maybe_remove_all_loras(runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        if self.model_config.enforce_eager:
            cuda_graph_memory_bytes = 0
        else:
            cuda_graph_memory_bytes = runner.capture_model()

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            safety_margin = 150 * (1 << 20)
            non_kv_cache_memory = (
                runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory - non_kv_cache_memory - safety_margin
            )
            kv_bytes_to_requested_limit = (
                int(self.requested_memory) - non_kv_cache_memory - safety_margin
            )

            msg = (
                f"Free memory on device "
                f"({to_gib(self.init_snapshot.free_memory)}/"
                f"{to_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{to_gib(self.requested_memory)} GiB). "
                f"Actual usage is {to_gib(runner.model_memory_usage)} "
                f"GiB for weight, {to_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {to_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {to_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_bytes_to_requested_limit}` "
                f"({to_gib(kv_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_bytes_to_gpu_limit}` "
                f"({to_gib(kv_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{to_gib(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            scheduler_config = self.scheduler_config
            max_num_reqs = min(
                scheduler_config.max_num_seqs,
                scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if runner.is_pooling_model:
                runner._dummy_pooler_run(hidden_states)
            else:
                runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

451
452
453
    def reset_mm_cache(self) -> None:
        """Delegate the multimodal-cache reset to the model runner."""
        runner = self.model_runner
        runner.reset_mm_cache()

454
455
456
    def get_model(self) -> nn.Module:
        """Return the model object exposed by the model runner."""
        runner = self.model_runner
        return runner.get_model()

457
458
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        runner = self.model_runner
        return runner.get_supported_tasks()
459

460
461
462
463
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one scheduler step on this worker.

        When tokens are scheduled and this is not the first
        pipeline-parallel rank, intermediate tensors are first received
        from the pipeline group. If the model runner returns a
        ``ModelRunnerOutput`` or ``AsyncModelRunnerOutput`` it is returned
        directly. Otherwise the runner returned ``IntermediateTensors``,
        which are sent on through the pipeline group.

        Returns:
            The runner's output; or ``None`` when intermediate tensors were
            forwarded and there is no KV connector output; or
            ``EMPTY_MODEL_RUNNER_OUTPUT`` (or a shallow copy of it carrying
            the KV connector output) otherwise.
        """
        pp_group = get_pp_group()
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        residual_is_scattered = is_residual_scattered_for_sp(
            self.vllm_config, num_input_tokens
        )
        all_gather_tensors = {"residual": not residual_is_scattered}

        intermediate_tensors = None
        if num_scheduled_tokens > 0 and not pp_group.is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            intermediate_tensors = IntermediateTensors(received)

        output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        passthrough_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        passthrough_output.kv_connector_output = kv_connector_output
        return passthrough_output
511

512
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return whatever the model runner's ``take_draft_token_ids`` yields."""
        runner = self.model_runner
        return runner.take_draft_token_ids()

515
    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        Args:
            is_start: ``True`` starts the profiler; ``False`` stops it and,
                when ``local_rank`` is 0, prints the key-averages table
                sorted by ``self_cuda_time_total``.

        Raises:
            RuntimeError: If no profiler was configured.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")

        if is_start:
            profiler.start()
            return

        profiler.stop()
        # only print profiler results on rank 0
        if self.local_rank == 0:
            summary = profiler.key_averages().table(sort_by="self_cuda_time_total")
            print(summary)
527

528
    def execute_dummy_batch(self) -> None:
        """Run a one-token dummy pass with ``uniform_decode=True``."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)
530

531
532
533
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward the LoRA request to the model runner; return its result."""
        runner = self.model_runner
        return runner.add_lora(lora_request)

534
535
536
    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model runner to remove ``lora_id``; return its result."""
        runner = self.model_runner
        return runner.remove_lora(lora_id)

537
    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        runner = self.model_runner
        return runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model runner to pin ``lora_id``; return its result."""
        runner = self.model_runner
        return runner.pin_lora(lora_id)

543
544
545
546
    def check_health(self) -> None:
        """Health probe that always succeeds by returning ``None``."""
        # worker will always be healthy as long as it's running.
        return None

547
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts ahead of shrinking the expert-parallel group.

        Each old EP rank below ``new_ep_size`` maps to itself; every other
        old rank maps to ``-1``. The mapping is passed to the EPLB state's
        ``rearrange`` with ``execute_shuffle=True``, followed by a CUDA
        synchronize.

        Raises:
            AssertionError: If the model runner has no EPLB state.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        rank_mapping = {}
        for ep_rank in range(old_ep_size):
            survives = ep_rank < new_ep_size
            rank_mapping[ep_rank] = ep_rank if survives else -1

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_load: torch.Tensor | None,
    ) -> None:
        """Reshard experts after the expert-parallel group has grown.

        Every old EP rank maps to itself. The mapping and
        ``global_expert_load`` are passed to the EPLB state's ``rearrange``
        with ``execute_shuffle=True``. ``new_ep_size`` is accepted but not
        read here.

        Raises:
            AssertionError: If the model runner has no EPLB state.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        identity_mapping = {}
        for ep_rank in range(old_ep_size):
            identity_mapping[ep_rank] = ep_rank

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=identity_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Copy data-parallel settings from the request into the config.

        The size, master IP and master port are always overwritten. The
        rank and local rank are overwritten only when the request does not
        carry ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.
        """
        parallel_config = self.vllm_config.parallel_config
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK

        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size

        requested_rank = reconfig_request.new_data_parallel_rank
        if requested_rank != keep_current:
            parallel_config.data_parallel_rank = requested_rank

        requested_local_rank = reconfig_request.new_data_parallel_rank_local
        if requested_local_rank != keep_current:
            parallel_config.data_parallel_rank_local = requested_local_rank

        master_ip = reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_ip = master_ip
        master_port = reconfig_request.new_data_parallel_master_port
        parallel_config.data_parallel_master_port = master_port
616

617
618
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """Update every MoE module for a new expert-parallel size.

        Modules whose class name is ``FusedMoE`` or ``SharedFusedMoE`` get
        their expert count and parallel config rewritten. Then the
        redundant-expert count, communication buffers and physical-expert
        metadata are refreshed.

        Returns:
            ``None`` when ``new_ep_size < old_ep_size``; otherwise the
            value returned by the EPLB state's
            ``rearrange(..., execute_shuffle=False)``.

        Raises:
            AssertionError: If the MoE modules disagree on
                ``num_local_experts`` or the EPLB state is missing.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig

        parallel_config = self.vllm_config.parallel_config
        moe_class_names = ("FusedMoE", "SharedFusedMoE")
        moe_modules = [
            candidate
            for candidate in self.model_runner.model.modules()
            if candidate.__class__.__name__ in moe_class_names
        ]
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            moe.moe_config.num_local_experts == num_local_experts
            for moe in moe_modules
        ), "All MoE modules must have the same number of experts"

        for moe in moe_modules:
            moe.moe_config.num_experts = num_local_experts * new_ep_size
            moe.global_num_experts = moe.moe_config.num_experts
            moe.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            moe.moe_config.moe_parallel_config = moe.moe_parallel_config

        if new_ep_size < old_ep_size:
            # Scale down: sizes are read from the existing EPLB state.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            eplb_state = self.model_runner.eplb_state
            new_physical_experts = eplb_state.physical_to_logical_map.shape[1]
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - eplb_state.logical_replica_count.shape[1]
            )
            global_expert_load = None
        else:
            # Otherwise the per-rank expert count is broadcast from source
            # rank 0 of the EP CPU group.
            local_count_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                local_count_tensor, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = local_count_tensor.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1]
            )

        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_load

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed state for a new DP size.

        Steps, in order:
          1. Compute the new EP size as
             ``new_data_parallel_size * tp_world_size * pp_world_size``.
          2. If the group shrinks, reshard experts first.
          3. Clean up the current distributed environment.
          4. If this rank is being shut down, stop here.
          5. Apply the request to the parallel config and re-initialize
             the distributed environment.
          6. Reconfigure the MoE modules and, if the group grew, reshard
             experts afterwards.

        Args:
            reconfig_request: New data-parallel size, ranks and master
                address to adopt.

        Raises:
            AssertionError: If a rank marked for shutdown has an old EP
                rank below the new EP size, or if no global expert load is
                available after scaling up.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Captured before cleanup, which tears the groups down.
        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            # Use the same backend as the initial setup in init_device
            # instead of the helper's hard-coded "nccl" default, so the
            # re-created process groups match the original ones.
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
                current_platform.dist_backend,
            )

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load)
734

735
736
737
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the runner's model through ``ShardedStateLoader.save_model``.

        ``path``, ``pattern`` and ``max_size`` are forwarded unchanged.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(
            model, path, pattern=pattern, max_size=max_size
        )

750
751
752
753
754
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Forward ``tensorizer_config`` to the runner's tensorized save."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)
757

758
    def shutdown(self) -> None:
        """Shut down KV transfer if a model runner exists.

        Does nothing when ``model_runner`` is missing or falsy.
        """
        runner = getattr(self, "model_runner", None)
        if not runner:
            return
        runner.ensure_kv_transfer_shutdown()
761

762
763

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Calls, in order: ``init_batch_invariance``, ``set_custom_all_reduce``,
    ``init_distributed_environment``, ``ensure_model_parallel_initialized``
    and ``ensure_kv_transfer_initialized``.

    Args:
        vllm_config: Source of the parallel config (world size and the
            tensor / pipeline / decode-context parallel sizes).
        rank: Passed to ``init_distributed_environment``.
        distributed_init_method: Passed to ``init_distributed_environment``.
        local_rank: Passed to ``init_distributed_environment``.
        backend: Passed to ``init_distributed_environment``.
    """
    parallel_config = vllm_config.parallel_config

    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()

    # Custom all-reduce stays enabled unless the config disables it.
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    world_size = parallel_config.world_size
    init_distributed_environment(
        world_size, rank, distributed_init_method, local_rank, backend
    )

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    dcp_size = parallel_config.decode_context_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size, dcp_size)

    ensure_kv_transfer_initialized(vllm_config)