gpu_worker.py 31.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from typing import TYPE_CHECKING, Any
10
11
12

import torch
import torch.distributed
13
import torch.nn as nn
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
19
20
21
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
22
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
23
24
25
26
from vllm.distributed.parallel_state import (
    get_pp_group,
    get_tp_group,
)
27
from vllm.logger import init_logger
28
from vllm.lora.request import LoRARequest
29
from vllm.model_executor import set_random_seed
30
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
31
from vllm.platforms import current_platform
32
from vllm.sequence import IntermediateTensors
33
from vllm.tasks import SupportedTask
34
35
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
36
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
37
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
38
39
40
41
42
43
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
44
from vllm.v1.utils import report_usage_stats
45
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
46
from vllm.v1.worker.utils import is_residual_scattered_for_sp
47
from vllm.v1.worker.worker_base import WorkerBase
48
49
50
51

# Module-level logger used by the worker and the helpers in this file.
logger = init_logger(__name__)

if TYPE_CHECKING:
    # Imported for type annotations only; these names are referenced as
    # string annotations so they are never needed at runtime.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput
54
55


56
class Worker(WorkerBase):
57
58
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Initialize the worker and, if enabled, its torch profiler.

        All arguments are forwarded unchanged to ``WorkerBase.__init__``.

        Args:
            vllm_config: Engine configuration; its ``instance_id`` is also
                used to name profiler traces.
            local_rank: Local rank of this worker.
            rank: Rank of this worker.
            distributed_init_method: Init method passed on later to the
                distributed environment setup.
            is_driver_worker: Whether this worker is the driver worker.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            # One trace file per engine instance and rank.
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
                ),
            )
        else:
            # `profile()` checks for None and raises if profiling is requested.
            self.profiler = None
114

115
    def sleep(self, level: int = 1) -> None:
        """Put the worker to sleep through ``CuMemAllocator`` and log the result.

        Args:
            level: When 1, the ``"weights"`` tag is passed to the allocator
                as the only offload tag; for any other value an empty tuple
                of offload tags is passed. When 2, the model's named buffers
                are first copied to CPU so that ``wake_up`` can restore them.

        Raises:
            AssertionError: If the free GPU memory reported after sleeping is
                lower than before.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone() for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )
138

139
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the allocator up and restore buffers saved by a level 2 sleep.

        Args:
            tags: Passed through unchanged to ``CuMemAllocator.wake_up``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers after level 2 sleep
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            # Drop the CPU copies so a later wake-up does not restore them again.
            self._sleep_saved_buffers = {}

153
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return the context manager under which ``tag`` allocations are made.

        Args:
            tag: Label handed to ``CuMemAllocator.use_memory_pool``.

        Returns:
            The allocator's memory-pool context when sleep mode is enabled in
            the model config, otherwise a ``nullcontext``.

        Raises:
            AssertionError: If ``tag`` is ``"weights"`` and the allocator
                already reports non-zero usage.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
                    "Sleep mode can only be used for one instance per process."
                )
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

167
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU and CPU block counts on the cache config.

        Args:
            num_gpu_blocks: Stored as ``cache_config.num_gpu_blocks``.
            num_cpu_blocks: Stored as ``cache_config.num_cpu_blocks``.
        """
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

171
    def init_device(self):
        """Bind the worker to its device and build the model runner.

        For a CUDA device this selects ``cuda:<local_rank>``, initializes the
        distributed environment, seeds the RNG, takes the initial memory
        snapshot and derives ``requested_memory`` from
        ``gpu_memory_utilization``. The model runner is constructed afterwards
        and rank 0 reports usage stats.

        Raises:
            ValueError: If the free memory in the initial snapshot is smaller
                than the requested memory.
            RuntimeError: If the configured device type is not ``"cuda"``.
        """
        if self.device_config.device.type == "cuda":
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.check_if_supports_dtype(self.model_config.dtype)

            # Initialize the distributed environment BEFORE taking
            # memory snapshot
            # This ensures NCCL buffers are allocated before we measure
            # available memory
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
                current_platform.dist_backend,
            )

            # Set random seed.
            set_random_seed(self.model_config.seed)

            # Now take memory snapshot after NCCL is initialized
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (
                self.init_snapshot.total_memory
                * self.cache_config.gpu_memory_utilization
            )
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

228
229
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
230
    def load_model(self) -> None:
        """Load the model through the model runner inside the weights pool context.

        ``eep_scale_up`` is true only when the environment variable
        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` is exactly ``"1"``.
        """
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)
234

235
236
237
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward ``overrides`` to the model runner's ``update_config``."""
        runner = self.model_runner
        runner.update_config(overrides)

238
    def reload_weights(self) -> None:
        """Ask the model runner to reload its weights."""
        self.model_runner.reload_weights()
240

241
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        If ``cache_config.kv_cache_memory_bytes`` is set, a profile run is
        still executed but that configured value is returned as-is.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            The number of bytes available for the KV cache.

        Raises:
            AssertionError: If free memory after profiling is not lower than
                the free memory in the initial snapshot.
        """
        GiB = lambda b: b / GiB_bytes
        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return kv_cache_memory_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        # Kept on the instance: compile_or_warm_up_model reads both values.
        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Free memory at startup that lies outside the requested budget.
        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            GiB(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            GiB(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            GiB(free_gpu_memory),
            GiB(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            GiB(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
327

328
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec mapping reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

331
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        Args:
            kv_cache_config: Configuration handed to the model runner's
                ``initialize_kv_cache``; a shallow copy of it is also attached
                to a shallow copy of the vLLM config for KV-transfer setup.
        """

        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        connector_vllm_config = copy.copy(self.vllm_config)
        connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
        ensure_kv_transfer_initialized(connector_vllm_config)

        # Same pool selection as for the weights, but tagged "kv_cache"; the
        # helper returns a nullcontext when sleep mode is disabled.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
352
353

    def compile_or_warm_up_model(self) -> None:
        """Warm up the model, optionally capture it, and pre-run the sampler.

        Steps, in order: dummy runs for the configured compile sizes, kernel
        warm-up, model capture unless ``enforce_eager`` is set, a debug message
        suggesting a KV cache size, a dummy sampler or pooler run on the last
        pipeline rank, and finally a reset of the random seed.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x
                for x in warmup_sizes
                if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        # LoRAs were kept across the dummy runs above; remove them once here.
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        # `peak_activation_memory` only exists if determine_available_memory
        # took the profiling path.
        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            GiB = lambda b: round(b / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({GiB(self.init_snapshot.free_memory)}/"
                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). "
                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{GiB(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

460
461
462
    def reset_mm_cache(self) -> None:
        """Delegate the multimodal cache reset to the model runner."""
        runner = self.model_runner
        runner.reset_mm_cache()

463
464
465
    def get_model(self) -> nn.Module:
        """Return the module provided by the model runner's ``get_model``."""
        model = self.model_runner.get_model()
        return model

466
467
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        supported = self.model_runner.get_supported_tasks()
        return supported
468

469
470
471
472
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one model step, exchanging tensors with pipeline neighbours.

        On a non-first pipeline rank with scheduled tokens, intermediate
        tensors are received from the pipeline group before the model runner
        executes. If the runner returns intermediate tensors instead of a
        final output, they are sent on through the pipeline group.

        Args:
            scheduler_output: Scheduler decision for this step.

        Returns:
            The model runner's output when it is a ``ModelRunnerOutput`` or
            ``AsyncModelRunnerOutput``. Otherwise ``None`` if the intermediate
            result has no KV connector output, ``EMPTY_MODEL_RUNNER_OUTPUT`` if
            that output is empty, or a shallow copy of
            ``EMPTY_MODEL_RUNNER_OUTPUT`` carrying the KV connector output.

        Raises:
            AssertionError: If the runner returns an unexpected type, or if
                intermediate tensors are produced on the last pipeline rank or
                under the ``external_launcher`` executor backend.
        """
        intermediate_tensors = None
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        # The same flags are used for both the receive and the send below.
        all_gather_tensors = {
            "residual": not is_residual_scattered_for_sp(
                self.vllm_config, num_input_tokens
            )
        }
        if forward_pass and not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )

        output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Copy first so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated.
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output
520

521
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the result of the model runner's ``take_draft_token_ids``."""
        return self.model_runner.take_draft_token_ids()

524
    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler created in ``__init__``.

        Args:
            is_start: Start the profiler when true; otherwise stop it and, on
                local rank 0, print the averaged results table sorted by
                ``self_cuda_time_total``.

        Raises:
            RuntimeError: If no profiler was configured.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            # only print profiler results on rank 0
            if self.local_rank == 0:
                print(
                    self.profiler.key_averages().table(sort_by="self_cuda_time_total")
                )
536

537
    def execute_dummy_batch(self) -> None:
        """Run a one-token dummy pass on the model runner with ``uniform_decode``."""
        self.model_runner._dummy_run(1, uniform_decode=True)
539

540
541
542
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Pass ``lora_request`` to the model runner and return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

543
544
545
    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model runner to remove ``lora_id`` and return its result."""
        removed = self.model_runner.remove_lora(lora_id)
        return removed

546
    def list_loras(self) -> set[int]:
        """Return the set of LoRA ids reported by the model runner."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model runner to pin ``lora_id`` and return its result."""
        pinned = self.model_runner.pin_lora(lora_id)
        return pinned

552
553
554
555
    def check_health(self) -> None:
        """Health probe that performs no checks and always returns ``None``."""
        # worker will always be healthy as long as it's running.
        return None

556
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts ahead of shrinking the expert-parallel group.

        Args:
            old_ep_size: Current expert-parallel size.
            new_ep_size: Expert-parallel size after scaling down.

        Raises:
            AssertionError: If the model runner has no EPLB state.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        # Ranks that survive keep their index; ranks at or beyond the new
        # size are mapped to -1.
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_load: torch.Tensor | None,
    ) -> None:
        """Reshard experts after the expert-parallel group has grown.

        Args:
            old_ep_size: Expert-parallel size before scaling up; every old
                rank is mapped to itself.
            new_ep_size: Expert-parallel size after scaling up. Not read by
                this method.
            global_expert_load: Passed through to the EPLB ``rearrange`` call.

        Raises:
            AssertionError: If the model runner has no EPLB state.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Update parallel config with provided reconfig_request

        The data-parallel size, master IP and master port are always
        overwritten. The data-parallel rank and local rank are overwritten
        only when the request does not carry
        ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
625

626
627
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """
        Reconfigure MoE modules for a new expert-parallel size

        Args:
            old_ep_size: Expert-parallel size before reconfiguration.
            new_ep_size: Expert-parallel size after reconfiguration.

        Returns:
            None when scaling down (new_ep_size < old_ep_size); otherwise the
            global expert load returned by the EPLB ``rearrange`` call.

        Raises:
            AssertionError: If the MoE modules disagree on their local expert
                count, or if the model runner has no EPLB state.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig

        parallel_config = self.vllm_config.parallel_config
        # Modules are matched by class name rather than by isinstance.
        moe_modules = [
            module
            for module in self.model_runner.model.modules()
            if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
        ]
        # NOTE(review): indexing [0] assumes the model has at least one MoE
        # module — confirm callers guarantee this.
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config
        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]
            )
            global_expert_load = None
        else:
            # Every rank adopts the local expert count broadcast from
            # source rank 0 of the EP CPU group.
            num_local_physical_experts = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = num_local_physical_experts.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1]
            )
        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_load

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed environment for a new size.

        The new expert-parallel size is the requested data-parallel size
        multiplied by the current TP and PP world sizes. When shrinking,
        experts are resharded before teardown; when growing, after the MoE
        modules have been reconfigured. A rank flagged
        ``SHUTDOWN_CURRENT_RANK`` returns right after teardown.

        Args:
            reconfig_request: New data-parallel size, ranks and master address.

        Raises:
            AssertionError: If a rank asked to shut down is still inside the
                new expert-parallel range, or if scaling up yields no global
                expert load.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Read the old group state before it is destroyed below.
        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            # NOTE(review): unlike init_device, no backend is passed here, so
            # the callee's default ("nccl") applies — confirm this is intended.
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load)
743

744
745
746
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model through ``ShardedStateLoader.save_model``.

        Args:
            path: Destination passed to the loader.
            pattern: Optional pattern forwarded to the loader.
            max_size: Optional size limit forwarded to the loader.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

759
760
761
762
763
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Forward ``tensorizer_config`` to the model runner's save routine."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )
766

767
    def shutdown(self) -> None:
        """Shut down KV transfer on the model runner, if a runner exists."""
        # `model_runner` is only set in init_device, so it may be missing.
        if runner := getattr(self, "model_runner", None):
            runner.ensure_kv_transfer_shutdown()
770

771
772

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Args:
        vllm_config: Supplies the parallel config (world size and the
            tensor/pipeline/decode-context parallel sizes).
        rank: Rank passed to ``init_distributed_environment``.
        distributed_init_method: Init method passed to
            ``init_distributed_environment``.
        local_rank: Local rank passed to ``init_distributed_environment``.
        backend: Backend name passed to ``init_distributed_environment``.
    """
    parallel_config = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(
        parallel_config.world_size, rank, distributed_init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.decode_context_parallel_size,
    )