gpu_worker.py 31.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4
import copy
5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from typing import TYPE_CHECKING, Any, Optional, Union
9
10
11

import torch
import torch.distributed
12
import torch.nn as nn
13

14
import vllm.envs as envs
15
from vllm.config import VllmConfig
16
17
18
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
19
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
20
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
from vllm.model_executor import set_random_seed
24
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
25
from vllm.platforms import current_platform
26
from vllm.sequence import IntermediateTensors
27
from vllm.tasks import SupportedTask
28
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
29
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
30
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
31
32
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, ModelRunnerOutput)
33
from vllm.v1.utils import report_usage_stats
34
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
35
from vllm.v1.worker.utils import is_residual_scattered_for_sp
36
from vllm.v1.worker.worker_base import WorkerBase
37
38
39
40

logger = init_logger(__name__)

# These imports are only needed for the string annotations used by Worker
# methods (e.g. "SchedulerOutput", "TensorizerConfig"); guarding them keeps
# them out of the runtime import graph.
if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput

class Worker(WorkerBase):
    """GPU worker that drives a ``GPUModelRunner`` on one CUDA device.

    Responsibilities visible in this class:
      * device and distributed setup;
      * model loading;
      * memory profiling to size the KV cache, and KV cache allocation;
      * compilation, CUDA graph capture and warm-up;
      * per-step execution, including pipeline-parallel tensor hand-off;
      * sleep mode, LoRA management and torch profiling;
      * elastic expert-parallel reconfiguration.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Store configuration and optionally create a torch profiler.

        The CUDA device and the model runner are created later, by
        ``init_device``.
        """
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Model buffers copied to CPU by a level-2 ``sleep``; ``wake_up``
        # copies them back into the model and clears this dict.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep to release GPU memory.

        At level 1 the allocator is asked to offload the "weights" tag; at
        any other level no tag is offloaded. At level 2 the model's named
        buffers are first copied to CPU so that ``wake_up`` can restore them.
        Logs how much memory was freed and how much is still in use.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else ())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the CuMem allocator and restore buffers saved by ``sleep``.

        Args:
            tags: Forwarded unchanged to ``CuMemAllocator.wake_up``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers after level 2 sleep
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self,
                                       tag: str) -> AbstractContextManager:
        """Return the CuMem memory-pool context for ``tag``.

        A ``nullcontext`` is returned when sleep mode is disabled. For the
        "weights" tag it asserts that the allocator has no current usage,
        i.e. sleep mode is used by a single instance per process.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
                    "Sleep mode can only be "
                    "used for one instance per process.")
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the number of GPU/CPU KV-cache blocks in the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind this worker to its CUDA device and set up distributed state.

        The steps are:
          1. Adjust NCCL-related environment variables.
          2. Select ``cuda:<local_rank>`` and check dtype support.
          3. Snapshot memory and check that the requested
             ``gpu_memory_utilization`` fits in the free memory.
          4. Initialize the distributed environment and set the random seed.
          5. Build the ``GPUModelRunner``.
          6. On rank 0, report usage stats.

        Raises:
            ValueError: If free memory at startup is below the requested
                amount.
            RuntimeError: If the configured device type is not CUDA.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.check_if_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (self.init_snapshot.total_memory *
                                     self.cache_config.gpu_memory_utilization)
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank,
                                            current_platform.dist_backend)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model weights.

        Loading happens inside the "weights" memory pool when sleep mode is
        enabled. ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH=1`` is forwarded to the
        model runner as ``eep_scale_up=True``.
        """
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Ask the model runner to reload the model weights."""
        self.model_runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        If ``cache_config.kv_cache_memory_bytes`` is set, a profile run is
        still executed (to compile the model for max_num_batched_tokens),
        but memory is not measured and that value is returned unchanged.

        On the profiling path this sets ``non_torch_memory``,
        ``peak_activation_memory`` and ``available_kv_cache_memory_bytes``.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        GiB = lambda b: b / GiB_bytes
        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory "
                f"{GiB(self.init_snapshot.free_memory):.2f} GiB, reserved "
                f"{GiB(kv_cache_memory_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly.")
            logger.info(msg)
            return kv_cache_memory_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
                weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
            f"GiB, current free memory {GiB(free_gpu_memory):.2f} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory)

        unrequested_memory = (self.init_snapshot.free_memory -
                              self.requested_memory)
        logger.debug(
            "Initial free memory: %.2f GiB; "
            "Requested memory: %.2f (util), %.2f GiB",
            GiB(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            GiB(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            GiB(free_gpu_memory),
            GiB(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
                    GiB(self.available_kv_cache_memory_bytes))
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's KV cache spec."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        With sleep mode enabled, the allocation happens inside the
        "kv_cache" CuMem memory pool.
        """
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile, warm up and (unless running eager) capture the model.

        Runs dummy batches for compile sizes that are not CUDA-graph capture
        sizes, warms up kernels, captures CUDA graphs when ``enforce_eager``
        is off, logs a suggested KV-cache size, warms up the sampler (or
        pooler) on the last pipeline rank, and finally resets the seed.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size,
                                         skip_eplb=True,
                                         remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        # The suggestion needs the numbers measured by
        # determine_available_memory's profiling path.
        if (self.cache_config.kv_cache_memory_bytes is None
                and hasattr(self, "peak_activation_memory")):
            self._log_kv_cache_memory_suggestion(cuda_graph_memory_bytes)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(
                    hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def _log_kv_cache_memory_suggestion(self,
                                        cuda_graph_memory_bytes: int) -> None:
        """Log suggested ``--kv-cache-memory`` values from measured usage.

        `memory_profiling` does not consider CUDAGraph memory size and may
        not utilize all gpu memory, so this logs two explicit KV cache sizes.
        One fits into the requested memory; the other uses all memory that
        was free at startup. Both leave a small redundancy buffer.

        Reads ``peak_activation_memory``, ``non_torch_memory`` and
        ``available_kv_cache_memory_bytes``, which are set by the profiling
        path of ``determine_available_memory``.
        """
        GiB = lambda b: round(b / GiB_bytes, 2)

        # empirically observed that the memory profiling may
        # slightly underestimate the memory consumption.
        # So leave a small buffer (=150MiB) to avoid OOM.
        redundancy_buffer_memory = 150 * (1 << 20)
        non_kv_cache_memory = (self.model_runner.model_memory_usage +
                               self.peak_activation_memory +
                               self.non_torch_memory +
                               cuda_graph_memory_bytes)
        kv_cache_memory_bytes_to_gpu_limit = (
            self.init_snapshot.free_memory - non_kv_cache_memory -
            redundancy_buffer_memory)
        kv_cache_memory_bytes_to_requested_limit = (
            int(self.requested_memory) - non_kv_cache_memory -
            redundancy_buffer_memory)

        msg = (
            f"Free memory on device "
            f"({GiB(self.init_snapshot.free_memory)}/"
            f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
            f"Desired GPU memory utilization is "
            f"({self.cache_config.gpu_memory_utilization}, "
            f"{GiB(self.requested_memory)} GiB). "
            f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
            f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
            f"for peak activation, {GiB(self.non_torch_memory)} GiB "
            f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
            f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
            f"config with `--kv-cache-memory="
            f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
            f"requested memory, or `--kv-cache-memory="
            f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
            f"utilize gpu memory. Current kv cache memory in use is "
            f"{int(self.available_kv_cache_memory_bytes)} bytes.")

        logger.info(msg)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the loaded model supports."""
        return self.model_runner.get_supported_tasks()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
        """Run one scheduler step on this worker.

        Non-first pipeline ranks first receive intermediate tensors from the
        previous stage, but only when tokens are scheduled. There are two
        outcomes:

        * If the runner returns a (possibly async) ``ModelRunnerOutput``,
          it is returned directly.
        * Otherwise the runner returns ``IntermediateTensors``, which are
          sent to the next stage. The method then returns None when there
          is no KV-connector output, and otherwise an empty runner output
          that carries it.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0
        num_input_tokens = self.model_runner._get_num_input_tokens(
            num_scheduled_tokens)
        # Whether "residual" is all-gathered over the TP group depends on
        # is_residual_scattered_for_sp at this padded token count; the same
        # setting is used for both the recv and the send below.
        all_gather_tensors = {
            "residual":
            not is_residual_scattered_for_sp(self.vllm_config,
                                             num_input_tokens)
        }
        intermediate_tensors = None
        if forward_pass and not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert parallel_config.distributed_executor_backend != (
            "external_launcher") and not get_pp_group().is_last_rank

        get_pp_group().send_tensor_dict(output.tensors,
                                        all_gather_group=get_tp_group(),
                                        all_gather_tensors=all_gather_tensors)

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if (not kv_connector_output.finished_sending
                and not kv_connector_output.finished_recving):
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is never mutated.
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return draft token ids from the model runner, if any."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        On stop, local rank 0 prints a table of profiler averages sorted by
        ``self_cuda_time_total``.

        Raises:
            RuntimeError: If profiling was not enabled through
                VLLM_TORCH_PROFILER_DIR.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            # only print profiler results on rank 0
            if self.local_rank == 0:
                print(self.profiler.key_averages().table(
                    sort_by="self_cuda_time_total"))

    def execute_dummy_batch(self) -> None:
        """Run a single-token dummy forward pass."""
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with ``lora_id``."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of the LoRA adapters known to the model runner."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with ``lora_id`` in the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """No-op health check."""
        # worker will always be healthy as long as it's running.
        return

    def _eplb_before_scale_down(self, old_ep_size: int,
                                new_ep_size: int) -> None:
        """Reshard experts off the EP ranks that are about to be removed.

        Old ranks ``>= new_ep_size`` are mapped to -1 in the rank mapping
        passed to ``eplb_state.rearrange``; the others keep their rank.
        """
        from vllm.distributed.parallel_state import get_ep_group
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding "
                        "before scaling down...")
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(self.model_runner.model,
                                               execute_shuffle=True,
                                               global_expert_load=None,
                                               rank_mapping=rank_mapping)
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
            self, old_ep_size: int, new_ep_size: int,
            global_expert_load: Optional[torch.Tensor]) -> None:
        """Reshard experts onto the enlarged EP group.

        Existing ranks keep their index in the rank mapping; the previously
        collected ``global_expert_load`` is passed to ``rearrange``.
        ``new_ep_size`` is accepted for symmetry but not used here.
        """
        from vllm.distributed.parallel_state import get_ep_group
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding "
                        "after scaling up...")
        rank_mapping = {
            old_ep_rank: old_ep_rank
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping)
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """
        Update parallel config with provided reconfig_request

        DP ranks are only overwritten when the request does not say
        ``KEEP_CURRENT_RANK``; size and master address/port always are.
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = (
            reconfig_request.new_data_parallel_size)
        if (reconfig_request.new_data_parallel_rank
                != ReconfigureRankType.KEEP_CURRENT_RANK):
            parallel_config.data_parallel_rank = (
                reconfig_request.new_data_parallel_rank)
        if (reconfig_request.new_data_parallel_rank_local
                != ReconfigureRankType.KEEP_CURRENT_RANK):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local)
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip)
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port)

    def _reconfigure_moe(self, old_ep_size: int,
                         new_ep_size: int) -> Optional[torch.Tensor]:
        """
        Reconfigure MoE modules with provided reconfig_request

        Updates expert counts and the MoE parallel config of every
        FusedMoE/SharedFusedMoE module, the EPLB redundant-expert count, the
        communication buffers and the model's physical-expert metadata.

        Return the global expert load if new_ep_size > old_ep_size,
        otherwise None
        """
        from vllm.distributed.parallel_state import (
            get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoEParallelConfig)

        parallel_config = self.vllm_config.parallel_config
        moe_modules = [
            module for module in self.model_runner.model.modules()
            if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
        ]
        assert moe_modules, "No FusedMoE modules found in the model"
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(module.moe_config.num_local_experts == num_local_experts
                   for module in moe_modules), (
                       "All MoE modules must have the same number of experts")
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config
        assert self.model_runner.eplb_state is not None
        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1])
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts -
                self.model_runner.eplb_state.logical_replica_count.shape[1])
            global_expert_load = None
        else:
            # Rank 0 of the EP group decides the per-rank physical expert
            # count; every rank adopts it.
            num_local_physical_experts = torch.tensor([num_local_experts],
                                                      dtype=torch.int32,
                                                      device="cpu")
            torch.distributed.broadcast(num_local_physical_experts,
                                        group=get_ep_group().cpu_group,
                                        group_src=0)
            num_local_physical_experts = num_local_physical_experts.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False)
            assert global_expert_load is not None, (
                "EPLB rearrange did not return the global expert load")
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1])
        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts)
        return global_expert_load

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """Rebuild distributed state for an elastic-EP resize.

        The steps are:
          1. When shrinking, reshard experts before the old process groups
             are torn down.
          2. A rank marked ``SHUTDOWN_CURRENT_RANK`` returns right after
             cleanup.
          3. Otherwise, update the parallel config and re-initialize the
             distributed environment with the platform's backend (as
             ``init_device`` does).
          4. Reconfigure the MoE modules.
          5. When growing, reshard experts onto the new ranks.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory, get_ep_group)

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (reconfig_request.new_data_parallel_size *
                       get_tp_group().world_size * get_pp_group().world_size)
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (reconfig_request.new_data_parallel_rank ==
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK):
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            # Use the same backend as init_device instead of silently
            # falling back to the function's "nccl" default.
            init_worker_distributed_environment(self.vllm_config, self.rank,
                                                self.distributed_init_method,
                                                self.local_rank,
                                                current_platform.dist_backend)

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size,
                                      global_expert_load)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the model via ``ShardedStateLoader.save_model``."""
        from vllm.model_executor.model_loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Save the model with tensorizer using ``tensorizer_config``."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    def shutdown(self) -> None:
        """Shut down the model runner's KV-transfer machinery."""
        self.model_runner.ensure_kv_transfer_shutdown()


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    The steps are:
      1. Enable custom all-reduce unless
         ``parallel_config.disable_custom_all_reduce`` is set.
      2. Call ``init_distributed_environment`` with the configured world
         size.
      3. Ensure the tensor-, pipeline- and decode-context-parallel groups
         exist.
      4. Ensure KV transfer is initialized.

    Args:
        vllm_config: Engine config; its ``parallel_config`` supplies the
            world size and the parallel sizes.
        rank: Global rank of this worker.
        distributed_init_method: Forwarded to
            ``init_distributed_environment``.
        local_rank: Forwarded to ``init_distributed_environment``.
        backend: torch.distributed backend name, forwarded unchanged.
    """
    parallel_config = vllm_config.parallel_config
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank, backend)

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.decode_context_parallel_size)

    ensure_kv_transfer_initialized(vllm_config)