gpu_worker.py 35 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from types import NoneType
10
from typing import TYPE_CHECKING, Any
11
12
13

import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import VllmConfig
18
19
20
21
22
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
23
24
25
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
28
29
30
31
from vllm.distributed.parallel_state import (
    get_pp_group,
    get_tp_group,
)
32
from vllm.logger import init_logger
33
from vllm.lora.request import LoRARequest
34
from vllm.model_executor import set_random_seed
35
from vllm.model_executor.models.interfaces import is_mixture_of_experts
36
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
37
from vllm.platforms import current_platform
38
from vllm.sequence import IntermediateTensors
39
from vllm.tasks import SupportedTask
40
41
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
42
from vllm.v1.core.sched.output import GrammarOutput
43
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
44
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
45
46
47
48
49
50
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
51
from vllm.v1.utils import report_usage_stats
52
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
53
from vllm.v1.worker.utils import is_residual_scattered_for_sp
54
from vllm.v1.worker.worker_base import WorkerBase
55
56
57
58

logger = init_logger(__name__)

if TYPE_CHECKING:
59
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
60
    from vllm.v1.core.sched.output import SchedulerOutput
61
62


63
class Worker(WorkerBase):
64
65
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Initialise base worker state and the optional torch profiler.

        Args:
            vllm_config: Forwarded to ``WorkerBase``; its ``instance_id`` is
                also used to name profiler traces.
            local_rank: Forwarded to ``WorkerBase``.
            rank: Forwarded to ``WorkerBase``; also part of the trace name.
            distributed_init_method: Forwarded to ``WorkerBase``.
            is_driver_worker: Forwarded to ``WorkerBase``.

        After construction ``self.profiler`` is a ``torch.profiler.profile``
        object when ``VLLM_TORCH_PROFILER_DIR`` is set, otherwise ``None``.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Filled by a level-2 `sleep` and consumed again by `wake_up`.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # The torch profiler is opt-in through the environment:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        if not trace_dir:
            self.profiler = None
            return

        worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
        logger.info(
            "Profiling enabled. Traces will be saved to: %s",
            trace_dir,
        )

        # Read each switch once; the same values feed the log line and the
        # profiler constructor.
        profiler_options = {
            "record_shapes": envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
            "profile_memory": envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
            "with_stack": envs.VLLM_TORCH_PROFILER_WITH_STACK,
            "with_flops": envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
        }
        logger.debug(
            "Profiler config: record_shapes=%s,"
            "profile_memory=%s,with_stack=%s,with_flops=%s",
            *profiler_options.values(),
        )

        trace_handler = torch.profiler.tensorboard_trace_handler(
            trace_dir, worker_name=worker_name, use_gzip=True
        )
        self.profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            on_trace_ready=trace_handler,
            **profiler_options,
        )
121

122
    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep and log how much memory was freed.

        Args:
            level: With ``1`` the allocator is asked to offload the
                ``"weights"`` tag; with any other value no offload tags are
                passed. With ``2`` every model buffer is additionally cloned
                to CPU beforehand so that ``wake_up`` can copy it back.

        Raises:
            AssertionError: If less device memory is free after sleeping
                than before.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        # Save the buffers before level 2 sleep
        if level == 2:
            cpu_copies: dict[str, torch.Tensor] = {}
            for buffer_name, buffer in self.model_runner.model.named_buffers():
                cpu_copies[buffer_name] = buffer.cpu().clone()
            self._sleep_saved_buffers = cpu_copies

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total_bytes = torch.cuda.mem_get_info()
        freed = free_after - free_before
        in_use = total_bytes - free_after
        assert freed >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed / GiB_bytes,
            in_use / GiB_bytes,
        )
145

146
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the CuMem allocator and restore buffers saved by ``sleep``.

        Args:
            tags: Passed unchanged to ``CuMemAllocator.wake_up``.

        If ``sleep`` stored CPU copies of model buffers, each model buffer
        with a saved copy is overwritten in place and the saved copies are
        then discarded.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        saved_buffers = self._sleep_saved_buffers
        if not saved_buffers:
            return

        # Restore the buffers after level 2 sleep
        for buffer_name, buffer in self.model_runner.model.named_buffers():
            if buffer_name in saved_buffers:
                buffer.data.copy_(saved_buffers[buffer_name].data)
        self._sleep_saved_buffers = {}

160
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
161
162
163
164
165
166
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
167
168
                    "Sleep mode can only be used for one instance per process."
                )
169
170
171
172
173
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

174
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU and CPU block counts on the cache config.

        Args:
            num_gpu_blocks: Stored as ``cache_config.num_gpu_blocks``.
            num_cpu_blocks: Stored as ``cache_config.num_cpu_blocks``.
        """
        cache_cfg = self.cache_config
        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks

178
    def init_device(self):
        """Bind the worker to its CUDA device and create the model runner.

        Steps, in order: adjust ``local_rank`` for data parallelism where
        applicable, select the device, initialise the distributed
        environment, seed the RNGs, take the initial memory snapshot and
        validate it against ``gpu_memory_utilization``, then construct the
        ``GPUModelRunner``. Rank 0 additionally reports usage stats.

        Raises:
            RuntimeError: If the configured device type is not ``"cuda"``.
            AssertionError: If the DP-adjusted local rank is not smaller
                than the number of visible CUDA devices.
            ValueError: If the free device memory at startup is below the
                requested amount.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        parallel_cfg = self.parallel_config
        shift_rank_for_dp = (
            parallel_cfg.data_parallel_size > 1
            and parallel_cfg.data_parallel_size_local > 0
            and parallel_cfg.distributed_executor_backend
            not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
        )
        if shift_rank_for_dp:
            # Use local DP rank if available, otherwise use global DP rank.
            dp_rank = parallel_cfg.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = parallel_cfg.data_parallel_rank

            ranks_per_dp_replica = (
                parallel_cfg.pipeline_parallel_size
                * parallel_cfg.tensor_parallel_size
            )

            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_rank * ranks_per_dp_replica
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # The distributed environment is brought up BEFORE the memory
        # snapshot so that NCCL buffers are already allocated when the
        # available memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Release whatever can be released, then snapshot current memory.
        gc.collect()
        torch.cuda.empty_cache()

        snapshot = MemorySnapshot()
        self.init_snapshot = snapshot
        self.requested_memory = (
            snapshot.total_memory * self.cache_config.gpu_memory_utilization
        )
        if snapshot.free_memory < self.requested_memory:

            def as_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            raise ValueError(
                f"Free memory on device "
                f"({as_gib(self.init_snapshot.free_memory)}/"
                f"{as_gib(self.init_snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{as_gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

258
259
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
260
    def load_model(self) -> None:
        """Load the model inside the ``"weights"`` memory-pool context.

        ``eep_scale_up`` is passed to the model runner as ``True`` exactly
        when the environment variable ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH``
        equals ``"1"``.
        """
        scale_up_flag = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH")
        is_scale_up_launch = scale_up_flag == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=is_scale_up_launch)
264

265
266
267
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward configuration overrides to the model runner.

        Args:
            overrides: Passed unchanged to ``model_runner.update_config``.
        """
        runner = self.model_runner
        runner.update_config(overrides)

268
    def reload_weights(self) -> None:
        """Ask the model runner to reload its weights."""
        runner = self.model_runner
        runner.reload_weights()
270

271
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return the number of bytes that may be given to the KV cache.

        When ``cache_config.kv_cache_memory_bytes`` is set, a profile run is
        still executed, but the configured value is returned as-is and
        memory profiling is skipped.

        Otherwise a profile run is executed under ``memory_profiling`` and
        the result is the requested memory minus the profiled non-KV-cache
        memory. ``non_torch_memory``, ``peak_activation_memory`` and
        ``available_kv_cache_memory_bytes`` are stored on the worker.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Raises:
            AssertionError: If free memory after profiling is not strictly
                below the initial free memory.
        """

        def to_gib(num_bytes):
            return num_bytes / GiB_bytes

        explicit_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if explicit_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            notice = (
                f"Initial free memory {to_gib(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {to_gib(explicit_kv_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(notice)
            return explicit_kv_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        initial_free = self.init_snapshot.free_memory
        free_after_profile = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert initial_free > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {to_gib(initial_free)} GiB, "
            f"current free memory {to_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Free memory that lies outside the requested budget.
        outside_budget = initial_free - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            to_gib(initial_free),
            self.cache_config.gpu_memory_utilization,
            to_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            to_gib(free_after_profile),
            to_gib(free_after_profile - outside_budget),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            to_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
357

358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return this worker's KV connector handshake metadata, if any.

        Returns:
            ``None`` when no KV transfer group exists or the connector
            reports no handshake metadata; otherwise a single-entry dict
            mapping this worker's rank within the TP group to the metadata.
        """
        if not has_kv_transfer_group():
            return None

        # Connectors that don't need to exchange handshake metadata across
        # workers report None here.
        handshake_metadata = get_kv_transfer_group().get_handshake_metadata()
        if handshake_metadata is None:
            return None

        return {get_tp_group().rank_in_group: handshake_metadata}

373
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        return kv_cache_spec

376
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        Args:
            kv_cache_config: Used both to initialise the KV transfer
                connector and to allocate the KV cache in the model runner.

        The allocation runs inside the ``"kv_cache"`` memory-pool context
        when sleep mode is enabled, and inside a no-op context otherwise.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Reuse the shared helper instead of duplicating the sleep-mode
        # check; for a non-"weights" tag it only selects the pool context.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
395
396

    def compile_or_warm_up_model(self) -> None:
        """Warm up the model, optionally capture CUDA graphs, warm up sampling.

        In order: dummy runs for the configured compile sizes, kernel
        warmup, CUDA graph capture unless ``enforce_eager`` is set, a debug
        message suggesting explicit KV cache sizes (only when memory
        profiling ran), a dummy sampler or pooler run on the last pipeline
        rank, and finally a re-seed of the random state.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        sizes_to_warm = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            sizes_to_warm = [
                size
                for size in sizes_to_warm
                if size
                not in self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(sizes_to_warm, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        profiling_was_used = self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        )
        if profiling_was_used:
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            def as_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            safety_margin_bytes = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            fit_gpu_bytes = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - safety_margin_bytes
            )
            fit_requested_bytes = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - safety_margin_bytes
            )

            suggestion = (
                f"Free memory on device "
                f"({as_gib(self.init_snapshot.free_memory)}/"
                f"{as_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{as_gib(self.requested_memory)} GiB). "
                f"Actual usage is {as_gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {as_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {as_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {as_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{fit_requested_bytes}` "
                f"({as_gib(fit_requested_bytes)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{fit_gpu_bytes}` "
                f"({as_gib(fit_gpu_bytes)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{as_gib(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(suggestion)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            scheduler_cfg = self.scheduler_config
            max_num_reqs = min(
                scheduler_cfg.max_num_seqs,
                scheduler_cfg.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

503
504
505
    def reset_mm_cache(self) -> None:
        """Ask the model runner to reset its multimodal cache."""
        runner = self.model_runner
        runner.reset_mm_cache()

506
507
508
    def get_model(self) -> nn.Module:
        """Return the model object held by the model runner."""
        model = self.model_runner.get_model()
        return model

509
510
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        supported_tasks = self.model_runner.get_supported_tasks()
        return supported_tasks
511

512
513
514
515
516
517
518
519
520
521
522
523
524
    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels one step in profiler traces.

        Args:
            scheduler_output: Object exposing ``scheduled_new_reqs`` and
                ``scheduled_cached_reqs.req_ids``; only their lengths are used.

        Returns:
            A ``nullcontext`` when no profiler is configured, otherwise a
            ``torch.profiler.record_function`` whose name encodes the number
            of new and cached requests so iterations are easy to tell apart.
        """
        if self.profiler:
            new_count = len(scheduler_output.scheduled_new_reqs)
            cached_count = len(scheduler_output.scheduled_cached_reqs.req_ids)
            label = f"execute_new_{new_count}_cached_{cached_count}"
            return torch.profiler.record_function(label)

        return nullcontext()

525
526
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Delegate token sampling to the model runner.

        Args:
            grammar_output: Passed unchanged to ``model_runner.sample_tokens``.

        Returns:
            Whatever the model runner returns.
        """
        sampled = self.model_runner.sample_tokens(grammar_output)
        return sampled

531
532
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one model step, relaying tensors between pipeline stages.

        When tokens are scheduled and this is not the first pipeline rank,
        intermediate tensors are first received from the PP group. If the
        model runner returns a ``ModelRunnerOutput`` or ``None`` it is
        returned directly. Otherwise the result must be
        ``IntermediateTensors``; its tensors are sent on through the PP
        group and only the KV connector output, if any, is returned.

        Args:
            scheduler_output: Scheduling decision for this step.

        Returns:
            The runner's output, ``None``, ``EMPTY_MODEL_RUNNER_OUTPUT``, or
            a shallow copy of it carrying the KV connector output.
        """
        scheduled_token_count = scheduler_output.total_num_scheduled_tokens
        has_forward_pass = scheduled_token_count > 0
        input_token_count = self.model_runner._get_num_input_tokens(
            scheduled_token_count
        )
        residual_needs_gather = not is_residual_scattered_for_sp(
            self.vllm_config, input_token_count
        )
        all_gather_tensors = {"residual": residual_needs_gather}

        received_tensors = None
        if has_forward_pass and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            received_tensors = IntermediateTensors(received)

        with self.annotate_profile(scheduler_output):
            runner_output = self.model_runner.execute_model(
                scheduler_output, received_tensors
            )
            if isinstance(runner_output, (ModelRunnerOutput, NoneType)):
                return runner_output

        assert isinstance(runner_output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            runner_output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        kv_connector_output = runner_output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        passthrough = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        passthrough.kv_connector_output = kv_connector_output
        return passthrough
584

585
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the draft token ids handed out by the model runner."""
        draft_token_ids = self.model_runner.take_draft_token_ids()
        return draft_token_ids

588
    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        Args:
            is_start: ``True`` starts the profiler; ``False`` stops it and,
                when ``local_rank`` is 0, prints the averaged results table
                sorted by ``self_cuda_time_total``.

        Raises:
            RuntimeError: If no profiler was configured.
        """
        active_profiler = self.profiler
        if active_profiler is None:
            raise RuntimeError("Profiler is not enabled.")

        if is_start:
            active_profiler.start()
            return

        active_profiler.stop()
        # only print profiler results on rank 0
        if self.local_rank == 0:
            summary = active_profiler.key_averages().table(
                sort_by="self_cuda_time_total"
            )
            print(summary)
600

601
    def execute_dummy_batch(self) -> None:
        """Run a one-token dummy batch with ``uniform_decode=True``."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)
603

604
605
606
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter through the model runner.

        Args:
            lora_request: Passed unchanged to ``model_runner.add_lora``.

        Returns:
            The model runner's result.
        """
        was_added = self.model_runner.add_lora(lora_request)
        return was_added

607
608
609
    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter through the model runner.

        Args:
            lora_id: Passed unchanged to ``model_runner.remove_lora``.

        Returns:
            The model runner's result.
        """
        was_removed = self.model_runner.remove_lora(lora_id)
        return was_removed

610
    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        lora_ids = self.model_runner.list_loras()
        return lora_ids

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter through the model runner.

        Args:
            lora_id: Passed unchanged to ``model_runner.pin_lora``.

        Returns:
            The model runner's result.
        """
        was_pinned = self.model_runner.pin_lora(lora_id)
        return was_pinned

616
617
618
619
    def check_health(self) -> None:
        """Health probe that performs no checks and returns ``None``."""
        # worker will always be healthy as long as it's running.
        return None

620
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts ahead of shrinking the expert-parallel group.

        Args:
            old_ep_size: Current number of EP ranks.
            new_ep_size: Number of EP ranks after scaling down.

        Raises:
            AssertionError: If the model runner has no EPLB state.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        # Ranks below new_ep_size map to themselves; the others map to -1
        # (presumably marking them as removed; confirm against
        # EplbState.rearrange).
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        # Use the same keyword spelling (`global_expert_loads`) as
        # `_eplb_after_scale_up`; the singular form did not match it.
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts after the expert-parallel group has grown.

        Args:
            old_ep_size: Number of EP ranks before scaling up.
            new_ep_size: Number of EP ranks after scaling up (not read here).
            global_expert_loads: Passed unchanged to ``eplb_state.rearrange``.

        Raises:
            AssertionError: If the model runner has no EPLB state.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        # Every pre-existing rank keeps its own index.
        identity_mapping = {ep_rank: ep_rank for ep_rank in range(old_ep_size)}
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=identity_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Copy the data-parallel settings of ``reconfig_request`` into the
        worker's parallel config.

        The DP size, master IP and master port are always overwritten. The
        global and local DP ranks are overwritten only when the request's
        value differs from ``ReconfigureRankType.KEEP_CURRENT_RANK``.
        """
        target = self.vllm_config.parallel_config
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK

        target.data_parallel_size = reconfig_request.new_data_parallel_size

        requested_rank = reconfig_request.new_data_parallel_rank
        if requested_rank != keep_current:
            target.data_parallel_rank = requested_rank

        requested_local_rank = reconfig_request.new_data_parallel_rank_local
        if requested_local_rank != keep_current:
            target.data_parallel_rank_local = requested_local_rank

        target.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip
        target.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
687

688
689
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """
        Re-shard the MoE layers after the expert-parallel (EP) size changed.

        Every FusedMoE / SharedFusedMoE module of the main model (and of the
        drafter model, when it is a mixture-of-experts model) gets its global
        expert count and MoE parallel config refreshed. The EPLB
        redundant-expert count is then recomputed, communication buffers are
        prepared again, and the main model's physical-expert metadata is
        updated.

        Args:
            old_ep_size: EP world size before the reconfiguration.
            new_ep_size: EP world size after the reconfiguration.

        Returns:
            None when new_ep_size < old_ep_size; otherwise the value returned
            by ``eplb_state.rearrange(execute_shuffle=False)``.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config
        moe_class_names = ("FusedMoE", "SharedFusedMoE")

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Modules are selected by class name.
            return [
                layer
                for layer in model.modules()
                if layer.__class__.__name__ in moe_class_names
            ]

        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
            # Validate every module first so nothing is mutated on a mismatch.
            assert not any(
                layer.moe_config.num_local_experts != num_local_experts
                for layer in moe_modules
            ), "All MoE modules must have the same number of experts"
            global_num_experts = num_local_experts * new_ep_size
            for layer in moe_modules:
                layer.moe_config.num_experts = global_num_experts
                layer.global_num_experts = layer.moe_config.num_experts
                layer.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                layer.moe_config.moe_parallel_config = layer.moe_parallel_config
            return moe_modules

        main_model = self.model_runner.model
        main_moe_modules = get_moe_modules(main_model)
        num_local_experts = main_moe_modules[0].moe_config.num_local_experts
        update_moe_modules(main_moe_modules, num_local_experts)

        # The drafter is optional: either attribute may be missing.
        drafter = getattr(self.model_runner, "drafter", None)
        drafter_model = getattr(drafter, "model", None)
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            drafter_local_experts = drafter_moe_modules[0].moe_config.num_local_experts
            assert drafter_local_experts == num_local_experts, (
                "Drafter and model configs should be the same"
            )
            update_moe_modules(drafter_moe_modules, num_local_experts)

        eplb_config = parallel_config.eplb_config
        if new_ep_size < old_ep_size:
            # Scale down: sizes are read from the existing EPLB state.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            eplb_state = self.model_runner.eplb_state
            new_physical_experts = eplb_state.physical_to_logical_map.shape[1]
            eplb_config.num_redundant_experts = (
                new_physical_experts - eplb_state.logical_replica_count.shape[1]
            )
            global_expert_loads = None
        else:
            # Otherwise: every rank adopts the local expert count broadcast
            # from source rank 0 of the EP CPU group.
            local_count = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                local_count, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = local_count.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_loads = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )

        prepare_communication_buffer_for_model(main_model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        main_model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads
787
788

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed state for a new DP layout.

        Steps, in order:
          1. Record the current expert-parallel (EP) size and rank, and
             compute the new EP size as new DP size * TP size * PP size.
          2. If the EP size shrinks, call ``_eplb_before_scale_down``.
          3. Clean up the distributed environment.
          4. Return early if this rank is marked for shutdown.
          5. Apply the request to the parallel config and re-initialize the
             worker's distributed environment.
          6. Reconfigure the MoE modules and, if the EP size grows, call
             ``_eplb_after_scale_up`` with the returned expert loads.

        Args:
            reconfig_request: Carries the new data-parallel size and rank.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Snapshot EP state before the process groups are destroyed below.
        previous_size = get_ep_group().world_size
        previous_rank = get_ep_group().rank

        dp_size = reconfig_request.new_data_parallel_size
        tp_size = get_tp_group().world_size
        pp_size = get_pp_group().world_size
        target_size = dp_size * tp_size * pp_size

        shrinking = target_size < previous_size
        growing = target_size > previous_size

        if shrinking:
            self._eplb_before_scale_down(previous_size, target_size)

        cleanup_dist_env_and_memory()

        requested_rank = reconfig_request.new_data_parallel_rank
        if requested_rank == ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            assert previous_rank >= target_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(previous_size, target_size)

        if not growing:
            return
        assert global_expert_loads is not None
        self._eplb_after_scale_up(previous_size, target_size, global_expert_loads)
832

833
834
835
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model runner's model via ``ShardedStateLoader.save_model``.

        Args:
            path: Forwarded to the loader as the save location.
            pattern: Forwarded unchanged as the ``pattern`` keyword.
            max_size: Forwarded unchanged as the ``max_size`` keyword.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        save_options = {"pattern": pattern, "max_size": max_size}
        ShardedStateLoader.save_model(self.model_runner.model, path, **save_options)

848
849
850
851
852
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Delegate tensorized saving to the model runner.

        Args:
            tensorizer_config: Passed through to the model runner's
                ``save_tensorized_model`` as a keyword argument.
        """
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)
855

856
    def shutdown(self) -> None:
        """Shut down the KV transfer side of the model runner, if there is one.

        Does nothing when ``model_runner`` is missing or falsy.
        """
        runner = getattr(self, "model_runner", None)
        if not runner:
            return
        runner.ensure_kv_transfer_shutdown()
859

860
861

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Calls ``init_batch_invariance()``, sets custom all-reduce to the inverse
    of ``parallel_config.disable_custom_all_reduce``, initializes the
    distributed environment, and ensures the model-parallel groups are
    initialized with the configured tensor, pipeline and decode-context
    parallel sizes.

    Args:
        vllm_config: Source of the parallel config.
        rank: Forwarded to ``init_distributed_environment``.
        distributed_init_method: Forwarded to
            ``init_distributed_environment``.
        local_rank: Forwarded to ``init_distributed_environment``.
        backend: Forwarded to ``init_distributed_environment``.
    """
    parallel_config = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()

    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    init_distributed_environment(
        parallel_config.world_size,
        rank,
        distributed_init_method,
        local_rank,
        backend,
    )

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    dcp_size = parallel_config.decode_context_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size, dcp_size)