gpu_worker.py 31 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from typing import TYPE_CHECKING, Any
10
11
12

import torch
import torch.distributed
13
import torch.nn as nn
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
19
20
21
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
22
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
23
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
24
from vllm.logger import init_logger
25
from vllm.lora.request import LoRARequest
26
from vllm.model_executor import set_random_seed
27
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
28
from vllm.platforms import current_platform
29
from vllm.sequence import IntermediateTensors
30
from vllm.tasks import SupportedTask
31
32
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
33
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
34
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
35
36
37
38
39
40
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
41
from vllm.v1.utils import report_usage_stats
42
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
43
from vllm.v1.worker.utils import is_residual_scattered_for_sp
44
from vllm.v1.worker.worker_base import WorkerBase
45
46
47
48

logger = init_logger(__name__)

if TYPE_CHECKING:
49
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
50
    from vllm.v1.core.sched.output import SchedulerOutput
51
52


53
class Worker(WorkerBase):
54
55
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Initialize the worker and, if requested via env vars, a profiler.

        Args:
            vllm_config: Configuration object forwarded to the base class;
                its `instance_id` is used in the profiler worker name.
            local_rank: Forwarded to the base class.
            rank: Forwarded to the base class.
            distributed_init_method: Forwarded to the base class.
            is_driver_worker: Forwarded to the base class.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Model buffers stashed by a level-2 `sleep` and restored by `wake_up`.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # The torch profiler is opt-in; it is only created when
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace is set.
        if not envs.VLLM_TORCH_PROFILER_DIR:
            self.profiler = None
            return

        trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
        logger.info("Profiling enabled. Traces will be saved to: %s", trace_dir)
        logger.debug(
            "Profiler config: record_shapes=%s,"
            "profile_memory=%s,with_stack=%s,with_flops=%s",
            envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
            envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
            envs.VLLM_TORCH_PROFILER_WITH_STACK,
            envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
        )

        profiler_kwargs = {
            "activities": [
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            "record_shapes": envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
            "profile_memory": envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
            "with_stack": envs.VLLM_TORCH_PROFILER_WITH_STACK,
            "with_flops": envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            # Traces are written gzip-compressed under `trace_dir`.
            "on_trace_ready": torch.profiler.tensorboard_trace_handler(
                trace_dir, worker_name=worker_name, use_gzip=True
            ),
        }
        self.profiler = torch.profiler.profile(**profiler_kwargs)
111

112
    def sleep(self, level: int = 1) -> None:
        """Put the allocator to sleep and log how much GPU memory was freed.

        Args:
            level: With 1, memory tagged "weights" is passed to the allocator
                as offload tags. With any other value no tags are offloaded;
                with exactly 2 the model's named buffers are first copied to
                CPU so that `wake_up` can restore them.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        # Save the buffers before level 2 sleep
        if level == 2:
            stashed: dict[str, torch.Tensor] = {}
            for buf_name, buf in self.model_runner.model.named_buffers():
                stashed[buf_name] = buf.cpu().clone()
            self._sleep_saved_buffers = stashed

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total_bytes = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        still_used_bytes = total_bytes - free_after
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            still_used_bytes / GiB_bytes,
        )
135

136
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the allocator and restore any buffers saved by `sleep`.

        Args:
            tags: Passed through unchanged to the allocator's `wake_up`.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        # Restore the buffers after level 2 sleep
        saved = self._sleep_saved_buffers
        if not saved:
            return

        for buf_name, buf in self.model_runner.model.named_buffers():
            if buf_name not in saved:
                continue
            buf.data.copy_(saved[buf_name].data)
        # Drop the CPU copies once they have been written back.
        self._sleep_saved_buffers = {}

150
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
151
152
153
154
155
156
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
157
158
                    "Sleep mode can only be used for one instance per process."
                )
159
160
161
162
163
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

164
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU and CPU block counts on the cache config."""
        cache_cfg = self.cache_config
        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks

168
    def init_device(self):
        """Select the CUDA device, init distributed state, build the runner.

        Also takes the startup memory snapshot (`self.init_snapshot`) and
        derives `self.requested_memory` from it.

        Raises:
            RuntimeError: If the configured device type is not "cuda".
            ValueError: If the free memory in the startup snapshot is lower
                than `total_memory * gpu_memory_utilization`.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # The distributed environment is initialized BEFORE the memory
        # snapshot so that NCCL buffers are already allocated when the
        # available memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Release whatever can be released, then snapshot current memory.
        gc.collect()
        torch.cuda.empty_cache()

        snapshot = MemorySnapshot()
        self.init_snapshot = snapshot
        utilization = self.cache_config.gpu_memory_utilization
        self.requested_memory = snapshot.total_memory * utilization

        if snapshot.free_memory < self.requested_memory:

            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            free_gib = to_gib(self.init_snapshot.free_memory)
            total_gib = to_gib(self.init_snapshot.total_memory)
            requested_gib = to_gib(self.requested_memory)
            raise ValueError(
                f"Free memory on device ({free_gib}/{total_gib} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{requested_gib} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

225
226
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
227
    def load_model(self) -> None:
        """Load the model inside the "weights" memory-pool context.

        The runner is told this is an elastic-EP scale-up launch when the
        VLLM_ELASTIC_EP_SCALE_UP_LAUNCH environment variable equals "1".
        """
        launch_flag = os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH")
        is_scale_up_launch = launch_flag == "1"
        weights_context = self._maybe_get_memory_pool_context(tag="weights")
        with weights_context:
            self.model_runner.load_model(eep_scale_up=is_scale_up_launch)
231

232
233
234
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward the config overrides to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

235
    def reload_weights(self) -> None:
        """Ask the model runner to reload its weights."""
        runner = self.model_runner
        runner.reload_weights()
237

238
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return the number of bytes that may be used for the KV cache.

        If `kv_cache_memory_bytes` is set on the cache config, a profile run
        is still executed but the configured value is returned as-is.
        Otherwise a profile run is measured with `memory_profiling`, and the
        result is the requested memory minus the measured non-KV-cache
        memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """

        def to_gib(num_bytes):
            return num_bytes / GiB_bytes

        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            initial_free_gib = to_gib(self.init_snapshot.free_memory)
            reserved_gib = to_gib(kv_cache_memory_bytes)
            manual_msg = (
                f"Initial free memory {initial_free_gib:.2f} "
                f"GiB, reserved {reserved_gib:.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(manual_msg)
            return kv_cache_memory_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        weights_bytes = int(self.model_runner.model_memory_usage)
        with memory_profiling(
            self.init_snapshot, weights_memory=weights_bytes
        ) as profile_result:
            self.model_runner.profile_run()

        # Kept on the instance; `compile_or_warm_up_model` reads them back.
        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_after_profile = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {to_gib(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {to_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )

        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Free memory at startup that lies outside the requested budget.
        outside_budget = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            to_gib(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            to_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            to_gib(free_after_profile),
            to_gib(free_after_profile - outside_budget),
        )
        logger.debug(profile_result)
        logger.info(
            "Available KV cache memory: %.2f GiB",
            to_gib(self.available_kv_cache_memory_bytes),
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
323

324
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

327
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The allocation runs inside the "kv_cache" memory-pool context, which
        is a no-op context when sleep mode is disabled.

        Args:
            kv_cache_config: Passed unchanged to the model runner's
                `initialize_kv_cache`.
        """
        # Reuse the shared helper instead of duplicating the sleep-mode
        # check; for the "kv_cache" tag it yields exactly the same context.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
339
340

    def compile_or_warm_up_model(self) -> None:
        """Warm up the model, capture CUDA graphs and warm up the sampler.

        Steps, in order: dummy runs for the configured compile sizes, kernel
        warmup, CUDA graph capture (skipped when `enforce_eager` is set), a
        debug-level memory report, a sampler/pooler warmup on the last
        pipeline rank, and finally a reset of the random seed.
        """
        compilation_config = self.vllm_config.compilation_config

        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        pending_sizes = compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            uncaptured = []
            for candidate in pending_sizes:
                if candidate not in compilation_config.cudagraph_capture_sizes:
                    uncaptured.append(candidate)
            pending_sizes = uncaptured

        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(pending_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        graph_bytes = 0
        if not self.model_config.enforce_eager:
            graph_bytes = self.model_runner.capture_model()

        profiled = hasattr(self, "peak_activation_memory")
        if self.cache_config.kv_cache_memory_bytes is None and profiled:
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            safety_margin = 150 * (1 << 20)

            non_kv_bytes = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + graph_bytes
            )
            fit_gpu_bytes = (
                self.init_snapshot.free_memory - non_kv_bytes - safety_margin
            )
            fit_requested_bytes = (
                int(self.requested_memory) - non_kv_bytes - safety_margin
            )

            report = (
                f"Free memory on device "
                f"({to_gib(self.init_snapshot.free_memory)}/"
                f"{to_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{to_gib(self.requested_memory)} GiB). "
                f"Actual usage is {to_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {to_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {to_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {to_gib(graph_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{fit_requested_bytes}` "
                f"({to_gib(fit_requested_bytes)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{fit_gpu_bytes}` "
                f"({to_gib(fit_gpu_bytes)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{to_gib(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(report)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if not self.model_runner.is_pooling_model:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)
            else:
                self.model_runner._dummy_pooler_run(hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

447
448
449
    def reset_mm_cache(self) -> None:
        """Delegate the multimodal-cache reset to the model runner."""
        runner = self.model_runner
        runner.reset_mm_cache()

450
451
452
    def get_model(self) -> nn.Module:
        """Return the model object exposed by the model runner."""
        runner = self.model_runner
        return runner.get_model()

453
454
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks reported by the model runner."""
        runner = self.model_runner
        return runner.get_supported_tasks()
455

456
457
458
459
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one step of the model for this worker's pipeline stage.

        Args:
            scheduler_output: Scheduling decision for this step.

        Returns:
            The runner's output when it is a `ModelRunnerOutput` or
            `AsyncModelRunnerOutput`. Otherwise the intermediate tensors are
            sent through the pipeline-parallel group and the result is
            `None`, `EMPTY_MODEL_RUNNER_OUTPUT`, or a shallow copy of it
            carrying the KV connector output.
        """
        scheduled_token_count = scheduler_output.total_num_scheduled_tokens
        has_forward_pass = scheduled_token_count > 0
        padded_token_count = self.model_runner._get_num_input_tokens(
            scheduled_token_count
        )
        residual_is_scattered = is_residual_scattered_for_sp(
            self.vllm_config, padded_token_count
        )
        all_gather_tensors = {"residual": not residual_is_scattered}

        # Stages after the first receive their inputs from the previous one.
        received = None
        if has_forward_pass and not get_pp_group().is_first_rank:
            tensor_dict = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            received = IntermediateTensors(tensor_dict)

        result = self.model_runner.execute_model(scheduler_output, received)
        if isinstance(result, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return result

        assert isinstance(result, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != ("external_launcher")
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            result.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        connector_output = result.kv_connector_output
        if not connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Copy so the shared empty-output object itself is left untouched.
        passthrough = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        passthrough.kv_connector_output = connector_output
        return passthrough
507

508
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return whatever the model runner's `take_draft_token_ids` yields."""
        runner = self.model_runner
        return runner.take_draft_token_ids()

511
    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        Args:
            is_start: Start the profiler when true, stop it otherwise.

        Raises:
            RuntimeError: If no profiler was configured.
        """
        active_profiler = self.profiler
        if active_profiler is None:
            raise RuntimeError("Profiler is not enabled.")

        if is_start:
            active_profiler.start()
            return

        active_profiler.stop()
        # only print profiler results on rank 0
        if self.local_rank != 0:
            return
        summary = active_profiler.key_averages().table(sort_by="self_cuda_time_total")
        print(summary)
523

524
    def execute_dummy_batch(self) -> None:
        """Run a one-token dummy batch with `uniform_decode=True`."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)
526

527
528
529
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward the LoRA request to the model runner and return its result."""
        runner = self.model_runner
        return runner.add_lora(lora_request)

530
531
532
    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model runner to remove `lora_id` and return its result."""
        runner = self.model_runner
        return runner.remove_lora(lora_id)

533
    def list_loras(self) -> set[int]:
        """Return the set of LoRA ids reported by the model runner."""
        runner = self.model_runner
        return runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model runner to pin `lora_id` and return its result."""
        runner = self.model_runner
        return runner.pin_lora(lora_id)

539
540
541
542
    def check_health(self) -> None:
        """Report health; being able to run this call is the whole check."""
        # worker will always be healthy as long as it's running.
        return None

543
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts before the expert-parallel group shrinks.

        Args:
            old_ep_size: Number of ranks in the current EP group.
            new_ep_size: Number of ranks after scaling down; old ranks at or
                above this value are mapped to -1.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        # Ranks that survive keep their id; the others are mapped to -1.
        rank_mapping = {}
        for ep_rank in range(old_ep_size):
            rank_mapping[ep_rank] = ep_rank if ep_rank < new_ep_size else -1

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_load: torch.Tensor | None,
    ) -> None:
        """Reshard experts after the expert-parallel group has grown.

        Args:
            old_ep_size: Number of ranks in the EP group before scaling up.
            new_ep_size: Number of ranks after scaling up (not read here).
            global_expert_load: Forwarded to the EPLB state's `rearrange`.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        # Every pre-existing rank keeps its own id.
        rank_mapping = {}
        for ep_rank in range(old_ep_size):
            rank_mapping[ep_rank] = ep_rank

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Copy the data-parallel settings of reconfig_request into the
        parallel config.

        The rank and local rank are only overwritten when the request does
        not carry the KEEP_CURRENT_RANK marker for them.
        """
        target = self.vllm_config.parallel_config
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK

        target.data_parallel_size = reconfig_request.new_data_parallel_size

        requested_rank = reconfig_request.new_data_parallel_rank
        if requested_rank != keep_current:
            target.data_parallel_rank = requested_rank

        requested_local_rank = reconfig_request.new_data_parallel_rank_local
        if requested_local_rank != keep_current:
            target.data_parallel_rank_local = requested_local_rank

        target.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip
        target.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
612

613
614
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """
        Reconfigure the MoE modules for a new expert-parallel size.

        Return the global expert load if new_ep_size >= old_ep_size,
        otherwise None
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig

        parallel_config = self.vllm_config.parallel_config
        model = self.model_runner.model

        # MoE layers are recognised by class name.
        moe_class_names = ("FusedMoE", "SharedFusedMoE")
        moe_modules = []
        for candidate in model.modules():
            if candidate.__class__.__name__ in moe_class_names:
                moe_modules.append(candidate)

        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            layer.moe_config.num_local_experts == num_local_experts
            for layer in moe_modules
        ), "All MoE modules must have the same number of experts"

        for layer in moe_modules:
            layer.moe_config.num_experts = num_local_experts * new_ep_size
            layer.global_num_experts = layer.moe_config.num_experts
            rebuilt_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            layer.moe_parallel_config = rebuilt_parallel_config
            layer.moe_config.moe_parallel_config = layer.moe_parallel_config

        if new_ep_size < old_ep_size:
            # Scale down: sizes are read from the existing EPLB state.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            eplb_state = self.model_runner.eplb_state
            new_physical_experts = eplb_state.physical_to_logical_map.shape[1]
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - eplb_state.logical_replica_count.shape[1]
            )
            global_expert_load = None
        else:
            # Otherwise the per-rank expert count is broadcast from source
            # rank 0 of the EP CPU group.
            local_count = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                local_count, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = local_count.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1]
            )

        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_load

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed state for a new DP size.

        Experts are resharded before the teardown when the EP group shrinks
        and after the rebuild when it grows. A rank whose request carries
        SHUTDOWN_CURRENT_RANK returns right after the teardown.

        Args:
            reconfig_request: New data-parallel size, ranks and master
                address to apply.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        shrinking = new_ep_size < old_ep_size
        growing = new_ep_size > old_ep_size

        if shrinking:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        shutting_down = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )
        if shutting_down:
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if growing:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load)
730

731
732
733
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the runner's model through `ShardedStateLoader.save_model`.

        Args:
            path: Passed positionally to `ShardedStateLoader.save_model`.
            pattern: Forwarded as the `pattern` keyword.
            max_size: Forwarded as the `max_size` keyword.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

746
747
748
749
750
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Forward the tensorizer config to the model runner's save method."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)
753

754
    def shutdown(self) -> None:
        """Shut down KV transfer if a model runner has been created."""
        # `model_runner` only exists once `init_device` has built it.
        runner = getattr(self, "model_runner", None)
        if not runner:
            return
        runner.ensure_kv_transfer_shutdown()
757

758
759

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Runs, in order: batch-invariance init, the custom all-reduce toggle,
    distributed environment init, model-parallel group init and KV transfer
    init.

    Args:
        vllm_config: Source of the parallel config; also passed to
            `ensure_kv_transfer_initialized`.
        rank: Forwarded to `init_distributed_environment`.
        distributed_init_method: Forwarded to `init_distributed_environment`.
        local_rank: Forwarded to `init_distributed_environment`.
        backend: Forwarded to `init_distributed_environment`.
    """
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    parallel_cfg = vllm_config.parallel_config

    init_batch_invariance()

    use_custom_all_reduce = not parallel_cfg.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    world_size = parallel_cfg.world_size
    init_distributed_environment(
        world_size, rank, distributed_init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_cfg.tensor_parallel_size,
        parallel_cfg.pipeline_parallel_size,
        parallel_cfg.decode_context_parallel_size,
    )

    ensure_kv_transfer_initialized(vllm_config)