gpu_worker.py 35.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from types import NoneType
10
from typing import TYPE_CHECKING, Any
11
12
13

import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import VllmConfig
18
19
20
21
22
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
23
24
25
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
28
29
30
31
from vllm.distributed.parallel_state import (
    get_pp_group,
    get_tp_group,
)
32
from vllm.logger import init_logger
33
from vllm.lora.request import LoRARequest
34
from vllm.model_executor import set_random_seed
35
from vllm.model_executor.models.interfaces import is_mixture_of_experts
36
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
37
from vllm.platforms import current_platform
38
from vllm.profiler.gpu_profiler import CudaProfilerWrapper
39
from vllm.sequence import IntermediateTensors
40
from vllm.tasks import SupportedTask
41
42
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
43
from vllm.v1.core.sched.output import GrammarOutput
44
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
45
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
46
47
48
49
50
51
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
52
from vllm.v1.utils import report_usage_stats
53
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
54
from vllm.v1.worker.utils import is_residual_scattered_for_sp
55
from vllm.v1.worker.worker_base import WorkerBase
56
57
58
59

logger = init_logger(__name__)

if TYPE_CHECKING:
60
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
61
    from vllm.v1.core.sched.output import SchedulerOutput
62
63


64
class Worker(WorkerBase):
65
66
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Initialize the worker and, if enabled via env vars, its profiler.

        All arguments are forwarded unchanged to ``WorkerBase.__init__``.

        After base initialization this:
        * initializes cached HF modules when ``trust_remote_code`` is set,
        * creates the (initially empty) buffer store used by level-2 sleep,
        * sets ``self.profiler`` to a ``torch.profiler.profile`` when
          ``VLLM_TORCH_PROFILER_DIR`` is set, to a ``CudaProfilerWrapper``
          when ``VLLM_TORCH_CUDA_PROFILE`` is set, and to ``None`` otherwise.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            # The trace file name embeds the instance id and this worker's rank.
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
                ),
            )
        elif envs.VLLM_TORCH_CUDA_PROFILE:
            self.profiler = CudaProfilerWrapper()
        else:
            self.profiler = None
124

125
    def sleep(self, level: int = 1) -> None:
        """Put the worker to sleep through the CuMem allocator.

        Args:
            level: With ``1`` the allocator is asked to offload the
                ``"weights"`` tag; with any other value no tag is offloaded.
                With ``2`` the model's named buffers are first copied to CPU
                so that ``wake_up`` can restore them.

        Logs how much GPU memory was freed and how much is still in use.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone() for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights",) if level == 1 else ())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )
148

149
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the worker up and restore any buffers saved by ``sleep``.

        Args:
            tags: Passed through unchanged to ``CuMemAllocator.wake_up``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers after level 2 sleep
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            # Drop the CPU copies once they have been written back.
            self._sleep_saved_buffers = {}

163
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
164
165
166
167
168
169
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
170
171
                    "Sleep mode can only be used for one instance per process."
                )
172
173
174
175
176
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

177
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU and CPU block counts on ``self.cache_config``."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

181
    def init_device(self):
        """Bind this worker to its CUDA device and build the model runner.

        For a CUDA device config this selects the device (adjusting
        ``self.local_rank`` for data parallelism when applicable),
        initializes the distributed environment, seeds the RNG, takes the
        initial memory snapshot in ``self.init_snapshot`` and computes
        ``self.requested_memory``. It then constructs ``self.model_runner``;
        rank 0 additionally reports usage stats.

        Raises:
            ValueError: If the free memory in the initial snapshot is less
                than the requested memory.
            RuntimeError: If the configured device type is not ``"cuda"``.
        """

        def _gib(num_bytes):
            # Bytes -> GiB rounded to two decimals, for the error message.
            return round(num_bytes / GiB_bytes, 2)

        if self.device_config.device.type == "cuda":
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            if (
                self.parallel_config.data_parallel_size > 1
                and self.parallel_config.data_parallel_size_local > 0
                and self.parallel_config.distributed_executor_backend
                not in ["ray", "external_launcher"]
                and self.vllm_config.parallel_config.data_parallel_backend != "ray"
            ):
                # Use local DP rank if available, otherwise use global DP rank.
                dp_local_rank = self.parallel_config.data_parallel_rank_local
                if dp_local_rank is None:
                    dp_local_rank = self.parallel_config.data_parallel_rank

                tp_pp_world_size = (
                    self.parallel_config.pipeline_parallel_size
                    * self.parallel_config.tensor_parallel_size
                )

                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
                self.local_rank += dp_local_rank * tp_pp_world_size
                assert self.local_rank < torch.cuda.device_count(), (
                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
                )

            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.check_if_supports_dtype(self.model_config.dtype)

            # Initialize the distributed environment BEFORE taking
            # memory snapshot
            # This ensures NCCL buffers are allocated before we measure
            # available memory
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
                current_platform.dist_backend,
            )

            # Set random seed.
            set_random_seed(self.model_config.seed)

            # Now take memory snapshot after NCCL is initialized
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (
                self.init_snapshot.total_memory
                * self.cache_config.gpu_memory_utilization
            )
            if self.init_snapshot.free_memory < self.requested_memory:
                raise ValueError(
                    f"Free memory on device "
                    f"({_gib(self.init_snapshot.free_memory)}/"
                    f"{_gib(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{_gib(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

261
262
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
263
    def load_model(self) -> None:
        """Load the model inside the ``"weights"`` memory-pool context.

        ``eep_scale_up`` is passed to the model runner as ``True`` only when
        the ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` environment variable is
        exactly ``"1"``.
        """
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)
267

268
269
270
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward ``overrides`` to the model runner's ``update_config``."""
        runner = self.model_runner
        runner.update_config(overrides)

271
    def reload_weights(self) -> None:
        """Delegate to the model runner's ``reload_weights``."""
        self.model_runner.reload_weights()
273

274
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``cache_config.kv_cache_memory_bytes`` when it is set (memory
            profiling is then skipped, but a profile run still executes);
            otherwise ``requested_memory`` minus the profiled non-KV-cache
            memory, as an ``int``.
        """

        def _gib(num_bytes):
            # Bytes -> GiB (unrounded; callers format it).
            return num_bytes / GiB_bytes

        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {_gib(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {_gib(kv_cache_memory_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return kv_cache_memory_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        # Kept on the instance; compile_or_warm_up_model reads both later.
        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {_gib(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {_gib(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            _gib(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            _gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            _gib(free_gpu_memory),
            _gib(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            _gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
360

361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return this worker's KV connector handshake metadata, if any.

        Returns:
            ``None`` when there is no KV transfer group or when the connector
            reports no handshake metadata; otherwise a one-entry dict mapping
            this worker's rank within the TP group to the metadata.
        """
        if has_kv_transfer_group():
            handshake = get_kv_transfer_group().get_handshake_metadata()
            # Connectors that don't need to exchange handshake metadata
            # across workers report None.
            if handshake is not None:
                return {get_tp_group().rank_in_group: handshake}
        return None

376
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

379
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""

        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Allocate under the "kv_cache" pool when sleep mode is enabled,
        # otherwise under a no-op context.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
398
399

    def compile_or_warm_up_model(self) -> None:
        """Warm up / compile the model and capture CUDA graphs.

        In order, this:
        1. runs dummy batches for the configured compile sizes (minus the
           cudagraph capture sizes unless ``enforce_eager`` is set),
        2. runs kernel warmup,
        3. captures the model unless ``enforce_eager`` is set,
        4. logs (at debug level) suggested KV cache memory sizes when memory
           profiling results are available,
        5. on the last PP rank, runs a dummy pooler or sampler pass,
        6. resets the random seed.
        """

        def _gib(num_bytes):
            # Bytes -> GiB rounded to two decimals, for the debug message.
            return round(num_bytes / GiB_bytes, 2)

        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x
                for x in warmup_sizes
                if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({_gib(self.init_snapshot.free_memory)}/"
                f"{_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{_gib(self.requested_memory)} GiB). "
                f"Actual usage is {_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({_gib(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({_gib(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{_gib(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

506
507
508
    def reset_mm_cache(self) -> None:
        """Delegate to the model runner's ``reset_mm_cache``."""
        runner = self.model_runner
        runner.reset_mm_cache()

509
510
511
    def get_model(self) -> nn.Module:
        """Return whatever the model runner's ``get_model`` returns."""
        model = self.model_runner.get_model()
        return model

512
513
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks as reported by the model runner."""
        tasks = self.model_runner.get_supported_tasks()
        return tasks
514

515
516
517
518
519
520
521
522
523
524
525
526
527
    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels one step in the profile.

        When a profiler is active the label encodes the number of new and
        cached requests in ``scheduler_output`` so iterations are easy to
        tell apart in the trace; without a profiler a ``nullcontext`` is
        returned.
        """
        if self.profiler:
            new_count = len(scheduler_output.scheduled_new_reqs)
            cached_count = len(scheduler_output.scheduled_cached_reqs.req_ids)
            label = f"execute_new_{new_count}_cached_{cached_count}"
            return torch.profiler.record_function(label)
        return nullcontext()

528
529
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Delegate sampling to the model runner and return its output."""
        return self.model_runner.sample_tokens(grammar_output)

534
535
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one model step for this worker's pipeline stage.

        When tokens are scheduled and this is not the first PP rank, the
        intermediate tensors are first received from the PP group. If the
        model runner returns a ``ModelRunnerOutput`` or ``None`` it is
        returned as is. Otherwise the runner returned ``IntermediateTensors``,
        which are sent on through the PP group; the return value is then
        ``None``, ``EMPTY_MODEL_RUNNER_OUTPUT``, or a shallow copy of it
        carrying the KV connector output.
        """
        intermediate_tensors = None
        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        # The same mapping is used for both the receive and the send below.
        all_gather_tensors = {
            "residual": not is_residual_scattered_for_sp(
                self.vllm_config, num_input_tokens
            )
        }
        if forward_pass and not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if isinstance(output, (ModelRunnerOutput, NoneType)):
                return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Shallow-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated.
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output
587

588
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the result of the model runner's ``take_draft_token_ids``."""
        return self.model_runner.take_draft_token_ids()

591
    def profile(self, is_start: bool = True):
        """Start or stop the configured profiler.

        Args:
            is_start: ``True`` starts the profiler; ``False`` stops it.

        Raises:
            RuntimeError: If no profiler was configured for this worker.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            # only print profiler results on rank 0
            if (
                isinstance(self.profiler, torch.profiler.profile)
                and self.local_rank == 0
            ):
                print(
                    self.profiler.key_averages().table(sort_by="self_cuda_time_total")
                )
606

607
    def execute_dummy_batch(self) -> None:
        """Run a dummy batch of size 1 with ``uniform_decode=True``."""
        self.model_runner._dummy_run(1, uniform_decode=True)
609

610
611
612
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model runner and return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

613
614
615
    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model runner's ``remove_lora``."""
        removed = self.model_runner.remove_lora(lora_id)
        return removed

616
    def list_loras(self) -> set[int]:
        """Return the result of the model runner's ``list_loras``."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model runner's ``pin_lora``."""
        pinned = self.model_runner.pin_lora(lora_id)
        return pinned

622
623
624
625
    def check_health(self) -> None:
        """No-op health check.

        The worker is considered healthy for as long as it is running, so
        there is nothing to verify here.
        """
        return None

626
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts ahead of shrinking the EP group.

        Builds a mapping from every old EP rank to itself when it is below
        ``new_ep_size`` and to ``-1`` otherwise, then asks the EPLB state to
        rearrange with that mapping and waits for the device to finish.

        Args:
            old_ep_size: Number of EP ranks before scaling down.
            new_ep_size: Number of EP ranks after scaling down.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        # NOTE(review): the sibling `_eplb_after_scale_up` passes this argument
        # as `global_expert_loads` (plural) -- confirm which keyword
        # `eplb_state.rearrange` actually accepts.
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts after the EP group has grown.

        Builds an identity mapping over the old EP ranks and asks the EPLB
        state to rearrange with it.

        Args:
            old_ep_size: Number of EP ranks before scaling up.
            new_ep_size: Number of EP ranks after scaling up (not read here).
            global_expert_loads: Forwarded unchanged to
                ``eplb_state.rearrange``.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=rank_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Update parallel config with provided reconfig_request

        The data-parallel size, master IP and master port are always
        overwritten. The global and local data-parallel ranks are overwritten
        only when the request does not carry
        ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
693

694
695
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """
        Reconfigure MoE modules for a new expert-parallel (EP) size.

        Every ``FusedMoE`` / ``SharedFusedMoE`` module of the main model (and
        of the drafter model, when one exists and is a mixture of experts) gets
        its global expert count and MoE parallel config rebuilt for
        ``new_ep_size``. Afterwards ``eplb_config.num_redundant_experts`` is
        recomputed, the communication buffers are prepared again, and the main
        model's physical-expert metadata is updated.

        Args:
            old_ep_size: EP world size before the reconfiguration.
            new_ep_size: EP world size after the reconfiguration.

        Returns:
            ``None`` when scaling down (``new_ep_size < old_ep_size``).
            Otherwise (``new_ep_size >= old_ep_size``) the global expert loads
            returned by ``eplb_state.rearrange(execute_shuffle=False)``.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Modules are matched by class name, exactly these two names.
            return [
                module
                for module in model.modules()
                if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
            ]

        def update_moe_modules(
            moe_modules: list[FusedMoE], num_local_experts: int
        ) -> None:
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of experts"
            for module in moe_modules:
                # The global expert count follows the new EP size.
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config

        model_moe_modules = get_moe_modules(self.model_runner.model)
        # Fail with a readable message instead of a bare IndexError below.
        assert model_moe_modules, "Model has no MoE modules to reconfigure"
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts

        update_moe_modules(model_moe_modules, num_local_experts)
        drafter_model = None
        if hasattr(self.model_runner, "drafter") and hasattr(
            self.model_runner.drafter, "model"
        ):
            drafter_model = self.model_runner.drafter.model
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            # Scale down: the physical expert count is read from the existing
            # EPLB state and no expert loads are returned.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]
            )
            global_expert_loads = None
        else:
            # Every rank adopts the local expert count broadcast from source
            # rank 0 of the EP CPU group.
            num_local_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_tensor, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = num_local_tensor.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_loads = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )

        prepare_communication_buffer_for_model(self.model_runner.model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads
793
794

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Rebuild the distributed environment for a new data-parallel size.

        Steps, in order:

        1. If the EP size shrinks, call ``_eplb_before_scale_down`` before the
           existing distributed environment is cleaned up.
        2. Call ``cleanup_dist_env_and_memory``.
        3. If the request marks this rank as ``SHUTDOWN_CURRENT_RANK``, return
           without re-initializing anything.
        4. Otherwise apply the new parallel config, re-run
           ``init_worker_distributed_environment`` and reconfigure the MoE
           modules.
        5. If the EP size grows, call ``_eplb_after_scale_up`` with the global
           expert loads returned by ``_reconfigure_moe``.

        Args:
            reconfig_request: Request describing the new data-parallel layout.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Capture the current group sizes/rank before the cleanup call below.
        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            # A rank being shut down must lie outside the new EP range.
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
838

839
840
841
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """
        Save the model runner's model through ``ShardedStateLoader.save_model``.

        Args:
            path: Destination passed through to ``save_model``.
            pattern: Optional ``pattern`` argument forwarded unchanged.
            max_size: Optional ``max_size`` argument forwarded unchanged.
        """
        # Imported at call time, as in the rest of this class's lazy imports.
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

854
855
856
857
858
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """
        Delegate to ``self.model_runner.save_tensorized_model``.

        Args:
            tensorizer_config: Configuration forwarded unchanged (by keyword)
                to the model runner.
        """
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )
861

862
    def shutdown(self) -> None:
        """
        Shut down KV transfer on the model runner, if there is one.

        Does nothing when ``self`` has no ``model_runner`` attribute or when
        that attribute is falsy (e.g. ``None``).
        """
        # getattr with a default: the attribute may not have been set yet.
        if runner := getattr(self, "model_runner", None):
            runner.ensure_kv_transfer_shutdown()
865

866
867

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Calls, in order: ``init_batch_invariance``, ``set_custom_all_reduce``,
    ``init_distributed_environment`` and ``ensure_model_parallel_initialized``,
    all driven by ``vllm_config.parallel_config``.

    Args:
        vllm_config: Configuration whose ``parallel_config`` supplies the world
            size and the tensor / pipeline / decode-context parallel sizes.
        rank: Rank forwarded to ``init_distributed_environment``.
        distributed_init_method: Init method forwarded to
            ``init_distributed_environment``.
        local_rank: Local rank forwarded to ``init_distributed_environment``.
        backend: Backend name forwarded to ``init_distributed_environment``.
    """
    parallel_config = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
    # Custom all-reduce is enabled unless the config explicitly disables it.
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(
        parallel_config.world_size, rank, distributed_init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.decode_context_parallel_size,
    )