# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""

import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
import torch.distributed
import torch.nn as nn

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)

if TYPE_CHECKING:
    # Annotation-only imports: these names are referenced as string
    # annotations below and are not imported at runtime.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
    """vLLM v1 worker that runs a model on one CUDA device.

    The worker owns a :class:`GPUModelRunner` and is responsible for:

    * device and distributed-environment setup (:meth:`init_device`),
    * model loading and memory profiling used to size the KV cache,
    * warm-up / CUDA-graph capture (:meth:`compile_or_warm_up_model`),
    * model execution, including pipeline-parallel tensor exchange,
    * sleep / wake-up through ``CuMemAllocator``,
    * elastic expert-parallel (EP) reconfiguration.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules

            init_cached_hf_modules()

        # CPU copies of the model buffers taken by a level-2 sleep; restored
        # (and cleared) by wake_up once the weights are awake again.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True
                ),
            )
        else:
            self.profiler = None

    @staticmethod
    def _gib(num_bytes: float) -> float:
        """Convert a byte count to GiB, rounded to two decimals for messages."""
        return round(num_bytes / GiB_bytes, 2)

    def sleep(self, level: int = 1) -> None:
        """Release GPU memory held by this worker via ``CuMemAllocator``.

        Args:
            level: at level 1 the ``"weights"`` tag is passed as an offload
                tag; any other level passes no offload tags. At level 2 the
                model buffers are first copied to CPU so that :meth:`wake_up`
                can restore them.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone() for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake up memory previously released by :meth:`sleep`.

        Args:
            tags: allocator tags to wake up. ``None`` is forwarded to the
                allocator as-is and is treated here as waking every tag,
                including ``"weights"``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers after level 2 sleep. Buffers are allocated in
        # load_model under the "weights" memory-pool context, so only copy
        # them back once that pool has been woken up; otherwise keep the
        # saved copies for a later wake_up call.
        weights_awake = tags is None or "weights" in tags
        if self._sleep_saved_buffers and weights_awake:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context that routes allocations into the tagged pool.

        When sleep mode is enabled this is ``CuMemAllocator.use_memory_pool``
        for ``tag``; otherwise it is a no-op ``nullcontext``.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
                    "Sleep mode can only be used for one instance per process."
                )
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the number of GPU/CPU KV-cache blocks in the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind this worker to its CUDA device and set up distributed state.

        Also records ``init_snapshot`` and ``requested_memory`` (used by
        :meth:`determine_available_memory`) and constructs the model runner.

        Raises:
            ValueError: if the free device memory at startup is below
                ``gpu_memory_utilization`` times the total device memory.
            RuntimeError: if the configured device type is not ``cuda``.
        """
        if self.device_config.device.type == "cuda":
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.check_if_supports_dtype(self.model_config.dtype)

            # Initialize the distributed environment BEFORE taking
            # memory snapshot
            # This ensures NCCL buffers are allocated before we measure
            # available memory
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
                current_platform.dist_backend,
            )

            # Set random seed.
            set_random_seed(self.model_config.seed)

            # Now take memory snapshot after NCCL is initialized
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (
                self.init_snapshot.total_memory
                * self.cache_config.gpu_memory_utilization
            )
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = self._gib
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model weights inside the ``"weights"`` memory pool."""
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config overrides to the model runner."""
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Ask the model runner to reload the model weights."""
        self.model_runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        If ``cache_config.kv_cache_memory_bytes`` is set, profiling is still
        run (to compile the model) but that value is returned unchanged.

        Returns:
            Number of bytes available for the KV cache.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        GiB = self._gib
        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return kv_cache_memory_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        # Kept on the instance: compile_or_warm_up_model uses them to suggest
        # an explicit kv_cache_memory_bytes value.
        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            GiB(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            GiB(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            GiB(free_gpu_memory),
            GiB(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info(
            "Available KV cache memory: %.2f GiB",
            GiB(self.available_kv_cache_memory_bytes),
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the per-layer KV-cache specs reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        Allocation happens inside the ``"kv_cache"`` memory pool when sleep
        mode is enabled.
        """
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up the model, capture CUDA graphs and warm the sampler.

        Also logs (at debug level) a suggested ``kv_cache_memory_bytes`` value
        when the KV-cache size came from memory profiling.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x
                for x in warmup_sizes
                if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            GiB = self._gib

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            # The suggestions below are printed as CLI values, so make sure
            # they render as integer byte counts.
            kv_cache_memory_bytes_to_gpu_limit = int(
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_cache_memory_bytes_to_requested_limit = int(
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({GiB(self.init_snapshot.free_memory)}/"
                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). "
                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{GiB(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the loaded model."""
        return self.model_runner.get_supported_tasks()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
        """Run one model step for ``scheduler_output``.

        Non-first pipeline ranks receive intermediate tensors from the
        previous rank before running. Non-last ranks send their intermediate
        tensors to the next rank and return ``None``, or an empty output that
        carries the KV-connector output when there is one.
        """
        intermediate_tensors = None
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        all_gather_tensors = {
            "residual": not is_residual_scattered_for_sp(
                self.vllm_config, num_input_tokens
            )
        }
        if forward_pass and not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )

        output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return (and hand over) the draft token ids from the model runner."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        Raises:
            RuntimeError: if profiling was not enabled via
                ``VLLM_TORCH_PROFILER_DIR``.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            # only print profiler results on local rank 0
            if self.local_rank == 0:
                print(
                    self.profiler.key_averages().table(sort_by="self_cuda_time_total")
                )

    def execute_dummy_batch(self) -> None:
        """Run a single-token uniform-decode dummy step."""
        self.model_runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Move experts off the EP ranks that are about to be removed.

        EP ranks ``>= new_ep_size`` are mapped to ``-1``; the remaining ranks
        keep their index.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_load: Optional[torch.Tensor],
    ) -> None:
        """Redistribute experts onto the enlarged EP group.

        Existing EP ranks keep their index in ``rank_mapping``;
        ``global_expert_load`` is the load gathered before the scale-up.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Update parallel config with provided reconfig_request

        Ranks set to ``ReconfigureRankType.KEEP_CURRENT_RANK`` are left as-is.
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> Optional[torch.Tensor]:
        """
        Reconfigure the model's MoE modules for an EP size change.

        Updates each ``FusedMoE``/``SharedFusedMoE`` module's expert counts and
        parallel config, the EPLB redundant-expert count, the communication
        buffers and the model's physical-expert metadata.

        Returns:
            ``None`` when scaling down (``new_ep_size < old_ep_size``);
            otherwise the global expert load returned by
            ``eplb_state.rearrange(..., execute_shuffle=False)``.

        Raises:
            RuntimeError: if the model contains no MoE modules.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig

        parallel_config = self.vllm_config.parallel_config
        moe_modules = [
            module
            for module in self.model_runner.model.modules()
            if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
        ]
        if not moe_modules:
            raise RuntimeError(
                "Elastic EP reconfiguration requires a model with MoE modules, "
                "but none were found."
            )
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config
        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]
            )
            global_expert_load = None
        else:
            # Take the local expert count from EP rank 0 so every rank
            # (including newly launched ones) agrees on it.
            num_local_physical_experts = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = num_local_physical_experts.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1]
            )
        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_load

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed environment for elastic EP.

        The new EP size is ``new_data_parallel_size * TP size * PP size``.
        When scaling down, experts are resharded before teardown; ranks
        marked ``SHUTDOWN_CURRENT_RANK`` return after teardown. When scaling
        up, experts are resharded after the MoE modules are reconfigured.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            # Use the same backend as init_device so the rebuilt process
            # groups match the ones created at startup.
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
                current_platform.dist_backend,
            )

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the model state with ``ShardedStateLoader.save_model``."""
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the model with tensorizer via the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )

    def shutdown(self) -> None:
        """Shut down KV transfer if the model runner was ever created."""
        if runner := getattr(self, "model_runner", None):
            runner.ensure_kv_transfer_shutdown()


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Configures custom all-reduce, initializes the global process group,
    creates the model-parallel groups (tensor, pipeline and decode-context
    parallel) and sets up KV transfer.

    Args:
        vllm_config: config whose ``parallel_config`` supplies the world size
            and parallel sizes.
        rank: global rank of this worker.
        distributed_init_method: init method (e.g. a URL) forwarded to
            ``init_distributed_environment``.
        local_rank: rank of this worker on its node.
        backend: ``torch.distributed`` backend name.
    """
    parallel_config = vllm_config.parallel_config
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(
        parallel_config.world_size, rank, distributed_init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.decode_context_parallel_size,
    )

    ensure_kv_transfer_initialized(vllm_config)