gpu_worker.py 34.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any
10
11
12

import torch
import torch.distributed
13
import torch.nn as nn
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
19
20
21
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
22
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
23
24
25
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
28
from vllm.distributed.parallel_state import (
29
    get_pcp_group,
30
31
32
    get_pp_group,
    get_tp_group,
)
33
from vllm.logger import init_logger
34
from vllm.lora.request import LoRARequest
35
from vllm.model_executor import set_random_seed
36
from vllm.model_executor.models.interfaces import is_mixture_of_experts
37
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
38
from vllm.platforms import current_platform
39
from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
40
from vllm.sequence import IntermediateTensors
41
from vllm.tasks import SupportedTask
42
43
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
44
from vllm.v1.core.sched.output import GrammarOutput
45
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
46
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
47
48
49
50
51
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
52
from vllm.v1.utils import report_usage_stats
53
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
54
from vllm.v1.worker.utils import is_residual_scattered_for_sp
55
from vllm.v1.worker.worker_base import WorkerBase
56
57
58
59

logger = init_logger(__name__)

if TYPE_CHECKING:
60
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
61
    from vllm.v1.core.sched.output import SchedulerOutput
62
63


64
class Worker(WorkerBase):
65
66
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create a GPU worker.

        Args:
            vllm_config: Full engine configuration.
            local_rank: Rank of this worker on its node; ``init_device`` may
                later offset it for data parallelism.
            rank: Global rank of this worker.
            distributed_init_method: Init method later handed to the
                distributed-environment setup in ``init_device``.
            is_driver_worker: Whether this worker acts as the driver.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # Lazy import to avoid importing torch before initializing.
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # CPU copies of model buffers taken by a level-2 sleep; emptied again
        # by wake_up() once they have been restored.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Profiler choice is driven by environment variables:
        #   VLLM_TORCH_PROFILER_DIR=/path/to/save/trace -> torch profiler
        #   VLLM_TORCH_CUDA_PROFILE                      -> CUDA profiler
        # With neither set, profiling stays disabled (profiler is None).
        if envs.VLLM_TORCH_PROFILER_DIR:
            self.profiler = TorchProfilerWrapper(
                worker_name=f"{vllm_config.instance_id}-rank-{self.rank}",
                local_rank=self.local_rank,
            )
        elif envs.VLLM_TORCH_CUDA_PROFILE:
            self.profiler = CudaProfilerWrapper()
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep and log how much memory was freed.

        Level 1 passes ``offload_tags=("weights",)`` to the allocator. Any
        other level passes no offload tags. Level 2 additionally copies every
        model buffer to CPU first, so that ``wake_up`` can restore it.

        Raises:
            AssertionError: If device free memory decreased across the call.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before = torch.cuda.mem_get_info()[0]

        # With no offload tags at level 2, keep CPU copies of the model
        # buffers so wake_up() can write them back.
        if level == 2:
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone()
                for buf_name, buf in self.model_runner.model.named_buffers()
            }

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed = free_after - free_before
        still_used = total - free_after
        assert freed >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed / GiB_bytes,
            still_used / GiB_bytes,
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the CuMem allocator and restore buffers saved by a level-2 sleep.

        Args:
            tags: Forwarded unchanged to ``CuMemAllocator.wake_up``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        saved = self._sleep_saved_buffers
        if not saved:
            return
        # Copy each saved buffer back into the live model buffer of the same
        # name, then drop the CPU copies.
        for buf_name, buf in self.model_runner.model.named_buffers():
            if buf_name in saved:
                buf.data.copy_(saved[buf_name].data)
        self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return the context manager to allocate memory tagged ``tag`` under.

        With sleep mode enabled, this is the CuMem allocator's memory pool for
        ``tag``; otherwise it is a no-op ``nullcontext``.

        Raises:
            AssertionError: For ``tag == "weights"`` when the allocator already
                reports non-zero usage.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        # The allocator must still be unused when weights are loaded; per the
        # message below, sleep mode supports only one instance per process.
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU and CPU KV-cache block counts on the cache config."""
        cache_cfg = self.cache_config
        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind this worker to its CUDA device and prepare distributed state.

        The steps run in this order:

        1. For single-node data parallelism, offset ``local_rank`` by the DP
           rank.
        2. Select the device and check dtype support.
        3. Initialize the distributed environment.
        4. Seed the RNGs.
        5. Take the baseline memory snapshot and check that the requested
           share of GPU memory is free.
        6. Build the model runner.

        Raises:
            RuntimeError: If the configured device type is not CUDA.
            ValueError: If free device memory on startup is below
                ``gpu_memory_utilization`` * total memory.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        parallel_config = self.parallel_config
        needs_dp_rank_offset = (
            parallel_config.data_parallel_size > 1
            and parallel_config.data_parallel_size_local > 0
            and parallel_config.distributed_executor_backend
            not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
            and self.vllm_config.parallel_config.nnodes_within_dp == 1
        )
        if needs_dp_rank_offset:
            # Prefer the node-local DP rank; fall back to the global DP rank.
            dp_rank = parallel_config.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = parallel_config.data_parallel_rank

            # local_rank := DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            ranks_per_dp_replica = (
                parallel_config.pipeline_parallel_size
                * parallel_config.tensor_parallel_size
            )
            self.local_rank += dp_rank * ranks_per_dp_replica
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            num_visible = (
                torch.cuda.device_count() if torch.cuda.is_available() else 0
            )
            assert parallel_config.local_world_size <= num_visible, (
                f"local_world_size ({parallel_config.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({num_visible})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Set up the distributed environment BEFORE the memory snapshot, so
        # NCCL buffers are already allocated when free memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        set_random_seed(self.model_config.seed)

        # Baseline snapshot: taken after NCCL init and after releasing
        # cached allocator blocks.
        gc.collect()
        torch.cuda.empty_cache()
        self.init_snapshot = MemorySnapshot()

        snapshot = self.init_snapshot
        self.requested_memory = (
            snapshot.total_memory * self.cache_config.gpu_memory_utilization
        )
        if snapshot.free_memory < self.requested_memory:

            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            raise ValueError(
                f"Free memory on device "
                f"({to_gib(snapshot.free_memory)}/"
                f"{to_gib(snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{to_gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model through the runner, inside the "weights" memory context.

        Elastic-EP scale-up mode is enabled when the environment variable
        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` equals ``"1"``.
        """
        scale_up_launch = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=scale_up_launch)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward configuration ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Delegate a weight reload to the model runner."""
        runner = self.model_runner
        runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return how many bytes of GPU memory the KV cache may use.

        If ``cache_config.kv_cache_memory_bytes`` is set, a profile run still
        happens (it compiles the model for ``max_num_batched_tokens``). The
        configured value is then returned as-is, without memory profiling.

        Otherwise a dummy forward pass is profiled against the startup
        snapshot. The budget is the requested memory (total *
        ``gpu_memory_utilization``) minus all non-KV-cache memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """

        def to_gib(num_bytes):
            return num_bytes / GiB_bytes

        manual_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if manual_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            notice = (
                f"Initial free memory {to_gib(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {to_gib(manual_kv_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(notice)
            return manual_kv_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile a forward pass with dummy inputs.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        init_free = self.init_snapshot.free_memory
        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert init_free > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {to_gib(init_free)} GiB, "
            f"current free memory {to_gib(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        unrequested_memory = init_free - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            to_gib(init_free),
            self.cache_config.gpu_memory_utilization,
            to_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            to_gib(free_gpu_memory),
            to_gib(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            to_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return ``{tp_rank: metadata}`` for this worker's KV connector.

        Returns ``None`` when no KV transfer group exists, or when the
        connector reports no handshake metadata to exchange across workers.
        """
        if has_kv_transfer_group():
            metadata = get_kv_transfer_group().get_handshake_metadata()
            if metadata is not None:
                return {get_tp_group().rank_in_group: metadata}
        return None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV-cache spec mapping reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The KV transfer connector is initialized first. The cache is then
        allocated inside the "kv_cache" memory-pool context, which is a no-op
        ``nullcontext`` unless sleep mode is enabled.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Use the shared helper instead of duplicating its sleep-mode branch.
        # For the "kv_cache" tag it yields the same allocator memory pool (or
        # nullcontext) that used to be built inline here.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile, warm up and optionally CUDA-graph-capture the model.

        Steps, in order:
          1. Dummy-run each ``compile_sizes`` entry, largest first. Unless
             ``enforce_eager`` is set, sizes that are also CUDA-graph capture
             sizes are skipped. LoRAs are removed afterwards.
          2. Warm up and tune kernels via ``kernel_warmup``.
          3. Capture CUDA graphs unless ``enforce_eager`` is set.
          4. If the KV-cache budget came from memory profiling, log suggested
             explicit KV-cache sizes at debug level.
          5. On the last pipeline-parallel rank, warm up the sampler (or the
             pooler for pooling models).
          6. Reset the random seed.
        """
        compilation_config = self.vllm_config.compilation_config

        # Sizes outside the cudagraph capture sizes that users still want
        # compiled for better performance, e.g. the max-num-batched-token size
        # in chunked prefill.
        warmup_sizes = compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            capture_sizes = compilation_config.cudagraph_capture_sizes
            warmup_sizes = [
                size for size in warmup_sizes if size not in capture_sizes
            ]

        # EPLB is skipped so the dummy runs do not record dummy metrics.
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        if self.model_config.enforce_eager:
            cuda_graph_memory_bytes = 0
        else:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        kv_budget_was_profiled = (
            self.cache_config.kv_cache_memory_bytes is None
            and hasattr(self, "peak_activation_memory")
        )
        if kv_budget_was_profiled:
            # memory_profiling does not account for CUDAGraph memory and may
            # leave GPU memory unused, so suggest explicit KV-cache sizes for
            # users who want fine-grained control.
            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            # Empirically, memory profiling may slightly underestimate memory
            # consumption, so keep a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)

            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            gpu_limit_bytes = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            requested_limit_bytes = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            suggestion = (
                f"Free memory on device "
                f"({to_gib(self.init_snapshot.free_memory)}/"
                f"{to_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{to_gib(self.requested_memory)} GiB). "
                f"Actual usage is {to_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {to_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {to_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {to_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{requested_limit_bytes}` "
                f"({to_gib(requested_limit_bytes)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{gpu_limit_bytes}` "
                f"({to_gib(gpu_limit_bytes)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{to_gib(self.available_kv_cache_memory_bytes)} GiB."
            )
            logger.debug(suggestion)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # EPLB is skipped so this dummy run does not record metrics.
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed so the random state is unaffected by model
        # initialization and profiling.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        """Delegate a multimodal-cache reset to the model runner."""
        runner = self.model_runner
        runner.reset_mm_cache()

    def get_model(self) -> nn.Module:
        """Return the ``nn.Module`` held by the model runner."""
        runner = self.model_runner
        return runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        runner = self.model_runner
        return runner.get_supported_tasks()

    def annotate_profile(self, scheduler_output):
        """Return a profiler annotation context for one scheduling step.

        With a profiler configured, this advances it by one step and labels
        the step with the number of new and cached requests, so iterations
        are easy to tell apart in traces. Without one, it returns a
        ``nullcontext``.
        """
        profiler = self.profiler
        if profiler:
            profiler.step()
            label = (
                f"execute_new_{len(scheduler_output.scheduled_new_reqs)}"
                f"_cached_{len(scheduler_output.scheduled_cached_reqs.req_ids)}"
            )
            return profiler.annotate_context_manager(label)
        return nullcontext()

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Sample tokens via the model runner, passing the optional grammar output."""
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)

    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one forward step, exchanging intermediate tensors across PP ranks.

        On a non-first pipeline rank with scheduled tokens, intermediate
        tensors are first received from the previous rank. A
        ``ModelRunnerOutput`` or ``None`` from the model runner is returned
        directly. ``IntermediateTensors`` are sent on to the next pipeline
        rank and ``None`` is returned.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        has_tokens = num_scheduled_tokens > 0
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        # The residual is all-gathered over the TP group unless
        # is_residual_scattered_for_sp reports it scattered at this size.
        all_gather_tensors = {
            "residual": not is_residual_scattered_for_sp(
                self.vllm_config, num_input_tokens
            )
        }

        intermediate_tensors = None
        if has_tokens and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            intermediate_tensors = IntermediateTensors(received)

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if isinstance(output, (ModelRunnerOutput, NoneType)):
                return output

        # Anything else must be intermediate tensors for the next PP stage.
        assert isinstance(output, IntermediateTensors)
        assert (
            self.vllm_config.parallel_config.distributed_executor_backend
            != "external_launcher"
            and not get_pp_group().is_last_rank
        )
        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )
        return None

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the model runner's draft token ids (``None`` when it has none)."""
        runner = self.model_runner
        return runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start the configured profiler, or stop it when ``is_start`` is false.

        Raises:
            RuntimeError: If no profiler was configured at construction.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiling is not enabled.")
        action = profiler.start if is_start else profiler.stop
        action()

    def execute_dummy_batch(self) -> None:
        """Run a size-1 dummy pass through the model runner with uniform decode."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add ``lora_request`` via the model runner and return its result."""
        runner = self.model_runner
        return runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove LoRA ``lora_id`` via the model runner and return its result."""
        runner = self.model_runner
        return runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        runner = self.model_runner
        return runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin LoRA ``lora_id`` via the model runner and return its result."""
        runner = self.model_runner
        return runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """No-op health check: a running worker is always considered healthy."""
        return None

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts away from the EP ranks a scale-down will remove.

        The rank mapping passed to ``eplb_state.rearrange`` keeps each rank
        below ``new_ep_size`` at its own index and maps every other old rank
        to ``-1``. The call waits on ``torch.cuda.synchronize()`` before
        reporting completion.

        Args:
            old_ep_size: Expert-parallel world size before scaling down.
            new_ep_size: Expert-parallel world size after scaling down.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        # Surviving ranks keep their index; removed ranks map to -1.
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        # Use the same keyword as `_eplb_after_scale_up` (`global_expert_loads`);
        # the singular `global_expert_load` previously passed here does not
        # match it. NOTE(review): confirm against `EplbState.rearrange`.
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts onto the EP ranks added by a scale-up.

        Each existing rank ``0 .. old_ep_size - 1`` maps to itself in the rank
        mapping. ``global_expert_loads`` is forwarded to
        ``eplb_state.rearrange``. ``new_ep_size`` is not read here.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        identity_mapping = {ep_rank: ep_rank for ep_rank in range(old_ep_size)}

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=identity_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Apply a distributed reconfiguration request to the parallel config.

        The DP size, master IP and master port are always overwritten. The
        global and local DP ranks are overwritten only when the request
        carries something other than ``ReconfigureRankType.KEEP_CURRENT_RANK``.
        """
        parallel_config = self.vllm_config.parallel_config
        keep_rank = ReconfigureRankType.KEEP_CURRENT_RANK

        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if reconfig_request.new_data_parallel_rank != keep_rank:
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if reconfig_request.new_data_parallel_rank_local != keep_rank:
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
664
    ) -> torch.Tensor | None:
665
666
667
668
669
670
671
        """
        Reconfigure MoE modules with provided reconfig_request

        Return the global expert load if new_ep_size > old_ep_size,
        otherwise None
        """
        from vllm.distributed.parallel_state import (
672
673
674
675
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
676
677
678
679
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )
680
681

        parallel_config = self.vllm_config.parallel_config
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            return [
                module
                for module in model.modules()
                if (
                    module.__class__.__name__ == "FusedMoE"
                    or module.__class__.__name__ == "SharedFusedMoE"
                )
            ]

        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of experts"
            for module in moe_modules:
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
703
                    pcp_size_=get_pcp_group().world_size,
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config
            return moe_modules

        model_moe_modules = get_moe_modules(self.model_runner.model)
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts

        update_moe_modules(model_moe_modules, num_local_experts)
        drafter_model = None
        if hasattr(self.model_runner, "drafter") and hasattr(
            self.model_runner.drafter, "model"
        ):
            drafter_model = self.model_runner.drafter.model
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

727
728
729
        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
730
            new_physical_experts = (
731
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
732
            )
733
            parallel_config.eplb_config.num_redundant_experts = (
734
735
736
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]
            )
737
            global_expert_loads = None
738
        else:
739
740
741
742
743
744
            num_local_physical_experts = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
            )
745
746
747
            num_local_physical_experts = num_local_physical_experts.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
748
749
            global_expert_loads = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
750
            )
751
            parallel_config.eplb_config.num_redundant_experts = (
752
                new_physical_experts - global_expert_loads[0].shape[1]
753
            )
754
        prepare_communication_buffer_for_model(self.model_runner.model)
755
756
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
757
758
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
759
760
            num_local_physical_experts=num_local_physical_experts,
        )
761
        return global_expert_loads
762
763

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this worker's distributed state for a new data-parallel size.

        Steps, in order:
          1. Capture the current expert-parallel (EP) world size and this
             worker's EP rank, and derive the target EP size from the requested
             data-parallel size and the current TP and PP group sizes.
          2. When shrinking, run the EPLB pre-scale-down hook while the old
             process groups still exist.
          3. Tear down the current distributed environment.
          4. If this rank is being removed, return immediately.
          5. Otherwise apply the new parallel config, re-initialize the
             distributed environment and reconfigure the MoE layers.
          6. When growing, pass the collected global expert loads to the
             EPLB post-scale-up hook.

        Args:
            reconfig_request: Supplies ``new_data_parallel_size`` and
                ``new_data_parallel_rank`` (which may be the sentinel
                ``ReconfigureRankType.SHUTDOWN_CURRENT_RANK``).
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        old_ep_size, old_ep_rank = (
            get_ep_group().world_size,
            get_ep_group().rank,
        )

        # NOTE(review): _reconfigure_moe rebuilds FusedMoEParallelConfig from
        # the TP, PCP and DP group sizes, whereas this product uses TP and PP.
        # Confirm PP (rather than PCP) is the intended factor here.
        requested_dp = reconfig_request.new_data_parallel_size
        tp_world = get_tp_group().world_size
        pp_world = get_pp_group().world_size
        new_ep_size = requested_dp * tp_world * pp_world

        scale_down = new_ep_size < old_ep_size
        scale_up = new_ep_size > old_ep_size

        if scale_down:
            # Runs while the old process groups still exist (before teardown).
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        rank_is_leaving = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )
        if rank_is_leaving:
            # A rank being removed must fall outside the new, smaller EP world.
            assert old_ep_rank >= new_ep_size
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if scale_up:
            # _reconfigure_moe returns expert loads only on scale-up.
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
807

808
809
810
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model runner's model in sharded form.

        Thin wrapper around ``ShardedStateLoader.save_model``. All arguments
        are forwarded unchanged; their exact semantics are defined by that
        loader.

        Args:
            path: Save destination, forwarded positionally.
            pattern: Optional pattern, forwarded as ``pattern=``.
            max_size: Optional size limit, forwarded as ``max_size=``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

823
824
825
826
827
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Save the model using ``tensorizer_config`` via the model runner.

        Args:
            tensorizer_config: Forwarded as the ``tensorizer_config`` keyword
                argument to ``self.model_runner.save_tensorized_model``.
        """
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)
830

831
    def shutdown(self) -> None:
        """Release worker-owned resources; safe on a partially set-up worker.

        Shuts down KV transfer through the model runner (when one exists),
        then shuts down the profiler (when one is configured).

        Both attributes are read with ``getattr`` so that calling this on a
        worker whose setup did not complete does not raise ``AttributeError``.
        The profiler shutdown runs in ``finally``, so it still happens if the
        KV-transfer shutdown raises; that exception is then propagated.
        """
        try:
            if runner := getattr(self, "model_runner", None):
                runner.ensure_kv_transfer_shutdown()
        finally:
            # Mirror the defensive model_runner lookup: the profiler attribute
            # may be absent if initialization failed part-way.
            profiler = getattr(self, "profiler", None)
            if profiler is not None:
                profiler.shutdown()
836

837
838

def init_worker_distributed_environment(
839
    vllm_config: VllmConfig,
840
    rank: int,
841
    distributed_init_method: str | None = None,
842
    local_rank: int = -1,
843
    backend: str = "nccl",
844
845
) -> None:
    """Initialize the distributed environment."""
846
    parallel_config = vllm_config.parallel_config
847
848
849
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
850
851
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

852
853
854
    init_distributed_environment(
        parallel_config.world_size, rank, distributed_init_method, local_rank, backend
    )
855

856
857
858
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
859
        parallel_config.prefill_context_parallel_size,
860
861
        parallel_config.decode_context_parallel_size,
    )
862
863
864
865

    # Init ec connector here before KV caches caches init
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)