gpu_worker.py 33.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from typing import TYPE_CHECKING, Any
10
11
12

import torch
import torch.distributed
13
import torch.nn as nn
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
19
20
21
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
22
23
24
25
26
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
27
28
29
30
from vllm.distributed.parallel_state import (
    get_pp_group,
    get_tp_group,
)
31
from vllm.logger import init_logger
32
from vllm.lora.request import LoRARequest
33
from vllm.model_executor import set_random_seed
34
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
35
from vllm.platforms import current_platform
36
from vllm.sequence import IntermediateTensors
37
from vllm.tasks import SupportedTask
38
39
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
40
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
41
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
42
43
44
45
46
47
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
48
from vllm.v1.utils import report_usage_stats
49
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
50
from vllm.v1.worker.utils import is_residual_scattered_for_sp
51
from vllm.v1.worker.worker_base import WorkerBase
52
53
54
55

logger = init_logger(__name__)

if TYPE_CHECKING:
    # Needed only by static type checkers; at runtime annotations such as
    # `scheduler_output: "SchedulerOutput"` stay as strings.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
61
62
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create a GPU worker.

        Args:
            vllm_config: Full engine configuration.
            local_rank: Rank used to pick the CUDA device (`cuda:<local_rank>`)
                in `init_device`.
            rank: Global rank of this worker.
            distributed_init_method: Forwarded to the distributed environment
                initialization in `init_device`.
            is_driver_worker: Whether this worker acts as the driver.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # Lazy import to avoid importing torch before initializing.
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Model buffers copied to CPU by a level-2 `sleep`; `wake_up` restores
        # and then clears them.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler, enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        if not trace_dir:
            self.profiler = None
            return

        worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
        logger.info(
            "Profiling enabled. Traces will be saved to: %s",
            trace_dir,
        )

        record_shapes = envs.VLLM_TORCH_PROFILER_RECORD_SHAPES
        profile_memory = envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY
        with_stack = envs.VLLM_TORCH_PROFILER_WITH_STACK
        with_flops = envs.VLLM_TORCH_PROFILER_WITH_FLOPS
        logger.debug(
            "Profiler config: record_shapes=%s,"
            "profile_memory=%s,with_stack=%s,with_flops=%s",
            record_shapes,
            profile_memory,
            with_stack,
            with_flops,
        )

        trace_handler = torch.profiler.tensorboard_trace_handler(
            trace_dir, worker_name=worker_name, use_gzip=True
        )
        self.profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=record_shapes,
            profile_memory=profile_memory,
            with_stack=with_stack,
            with_flops=with_flops,
            on_trace_ready=trace_handler,
        )

    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep and log how much memory was freed.

        Level 1 passes ``("weights",)`` as ``offload_tags``. Level 2 passes no
        offload tags, and first copies every model buffer to CPU so that
        ``wake_up`` can restore it.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        if level == 2:
            # Keep CPU copies of the model buffers for `wake_up`.
            named_buffers = self.model_runner.model.named_buffers()
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone() for buf_name, buf in named_buffers
            }

        offload_tags = ("weights",) if level == 1 else tuple()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        used_bytes = total - free_after
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the CuMem allocator and restore buffers saved by ``sleep``.

        Args:
            tags: Passed unchanged to ``CuMemAllocator.wake_up``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        saved = self._sleep_saved_buffers
        if not saved:
            # Only a level-2 sleep stashes buffers; nothing to restore.
            return

        for buf_name, buf in self.model_runner.model.named_buffers():
            if buf_name in saved:
                buf.data.copy_(saved[buf_name].data)
        self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return the context manager under which tagged allocations happen.

        With sleep mode enabled this is ``CuMemAllocator.use_memory_pool(tag=tag)``;
        otherwise it is a no-op ``nullcontext()``.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            # Existing usage before loading weights means another instance
            # in this process is already using the allocator.
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU and CPU KV-cache block counts on the cache config."""
        cache_cfg = self.cache_config
        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind the worker to its CUDA device and build the model runner.

        In order:
        - optionally offset ``local_rank`` by the DP rank;
        - select and validate the device;
        - initialize the distributed environment and seed the RNGs;
        - snapshot memory and check it against ``gpu_memory_utilization``;
        - construct the ``GPUModelRunner``.

        Raises:
            RuntimeError: if the configured device type is not ``cuda``.
            ValueError: if free memory at startup is smaller than
                ``total_memory * gpu_memory_utilization``.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        pcfg = self.parallel_config
        shift_by_dp_rank = (
            pcfg.data_parallel_size > 1
            and pcfg.data_parallel_size_local > 0
            and pcfg.distributed_executor_backend
            not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
        )
        if shift_by_dp_rank:
            # Prefer the local DP rank; fall back to the global DP rank.
            dp_rank = pcfg.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = pcfg.data_parallel_rank
            ranks_per_dp_replica = (
                pcfg.pipeline_parallel_size * pcfg.tensor_parallel_size
            )
            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_rank * ranks_per_dp_replica
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # snapshot, so NCCL buffers already exist when free memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        set_random_seed(self.model_config.seed)

        # Release garbage and cached blocks, then snapshot memory.
        gc.collect()
        torch.cuda.empty_cache()
        self.init_snapshot = MemorySnapshot()
        snapshot = self.init_snapshot
        self.requested_memory = (
            snapshot.total_memory * self.cache_config.gpu_memory_utilization
        )

        if snapshot.free_memory < self.requested_memory:

            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            raise ValueError(
                f"Free memory on device "
                f"({to_gib(snapshot.free_memory)}/"
                f"{to_gib(snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{to_gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model weights under the "weights" memory-pool context.

        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH=1`` in the environment is forwarded
        to the model runner as ``eep_scale_up=True``.
        """
        is_eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        weights_ctx = self._maybe_get_memory_pool_context(tag="weights")
        with weights_ctx:
            self.model_runner.load_model(eep_scale_up=is_eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward ``overrides`` to the model runner's ``update_config``."""
        runner = self.model_runner
        runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Delegate weight reloading to the model runner."""
        runner = self.model_runner
        runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return the number of bytes that can be used for the KV cache.

        If ``cache_config.kv_cache_memory_bytes`` is set, a profile run still
        happens (it compiles the model for ``max_num_batched_tokens``). Memory
        profiling is skipped and the configured value is returned unchanged,
        ignoring ``gpu_memory_utilization``.

        Otherwise a dummy forward pass is profiled. The result is
        ``requested_memory`` (set in ``init_device``) minus the non-KV-cache
        memory measured during profiling.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """

        def to_gib(num_bytes):
            return num_bytes / GiB_bytes

        manual_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if manual_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {to_gib(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {to_gib(manual_kv_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return manual_kv_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile a forward pass with dummy inputs to measure model memory use.
        snapshot = self.init_snapshot
        weights_bytes = int(self.model_runner.model_memory_usage)
        with memory_profiling(snapshot, weights_memory=weights_bytes) as result:
            self.model_runner.profile_run()

        self.non_torch_memory = result.non_torch_increase
        self.peak_activation_memory = result.torch_peak_increase

        free_after_profile = result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert snapshot.free_memory > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {to_gib(snapshot.free_memory)} GiB, "
            f"current free memory {to_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - result.non_kv_cache_memory
        )

        # Free memory at startup that lies outside the requested budget.
        unrequested = snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            to_gib(snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            to_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            to_gib(free_after_profile),
            to_gib(free_after_profile - unrequested),
        )
        logger.debug(result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            to_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return ``{tp_rank: metadata}`` from this worker's KV connector.

        Returns None when there is no KV transfer group, or when the
        connector has no handshake metadata to exchange across workers.
        """
        if has_kv_transfer_group():
            metadata = get_kv_transfer_group().get_handshake_metadata()
            if metadata is not None:
                return {get_tp_group().rank_in_group: metadata}
        return None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec mapping reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        First, the KV transfer connector is initialized from a shallow copy of
        the vLLM config that carries a copy of ``kv_cache_config``. Then the
        model runner allocates the KV cache under the "kv_cache" memory-pool
        context.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        connector_vllm_config = copy.copy(self.vllm_config)
        connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
        ensure_kv_transfer_initialized(connector_vllm_config)

        # Reuse the shared helper rather than re-implementing the sleep-mode
        # branch: it yields CuMemAllocator.use_memory_pool(tag="kv_cache")
        # when sleep mode is on and nullcontext() otherwise.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile and warm up the model, capture CUDA graphs, warm up sampling.

        Steps:
          1. Run dummy batches, largest first, for each ``compile_sizes``
             entry not covered by CUDA graph capture.
          2. Warm up kernels, then capture CUDA graphs unless
             ``enforce_eager`` is set.
          3. If the KV cache size came from memory profiling, log suggested
             ``--kv-cache-memory`` values at debug level.
          4. On the last pipeline stage, run the pooler or sampler once at
             its maximum request count.
          5. Re-seed the RNGs.
        """
        compilation_config = self.vllm_config.compilation_config

        # Warm up sizes that are not in cudagraph capture sizes but that users
        # still want compiled for better performance, e.g. the
        # max-num-batched token size in chunked prefill.
        warmup_sizes = compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            capture_sizes = compilation_config.cudagraph_capture_sizes
            warmup_sizes = [s for s in warmup_sizes if s not in capture_sizes]

        runner = self.model_runner
        # EPLB is skipped here since we don't want to record dummy metrics.
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        runner.maybe_remove_all_loras(runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = (
            0 if self.model_config.enforce_eager else runner.capture_model()
        )

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # `memory_profiling` does not account for CUDA graph memory, so a
            # profiled KV cache size may leave GPU memory unused. Log explicit
            # `--kv-cache-memory` values for users who want finer control.
            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            # Memory profiling was empirically observed to slightly
            # underestimate consumption, so keep a 150 MiB safety margin.
            redundancy_buffer_memory = 150 * (1 << 20)

            non_kv_cache_memory = (
                runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({to_gib(self.init_snapshot.free_memory)}/"
                f"{to_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{to_gib(self.requested_memory)} GiB). "
                f"Actual usage is {to_gib(runner.model_memory_usage)} "
                f"GiB for weight, {to_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {to_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {to_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({to_gib(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({to_gib(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{to_gib(self.available_kv_cache_memory_bytes)} GiB."
            )
            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            sched = self.scheduler_config
            max_num_reqs = min(sched.max_num_seqs, sched.max_num_batched_tokens)

            # EPLB is skipped here since we don't want to record dummy metrics.
            hidden_states, last_hidden_states = runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if runner.is_pooling_model:
                runner._dummy_pooler_run(hidden_states)
            else:
                runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed so the random state is unaffected by model
        # initialization and profiling.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        """Delegate multimodal cache reset to the model runner."""
        runner = self.model_runner
        runner.reset_mm_cache()

    def get_model(self) -> nn.Module:
        """Return the model module held by the model runner."""
        runner = self.model_runner
        return runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        runner = self.model_runner
        return runner.get_supported_tasks()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Run one model step for ``scheduler_output``.

        On a non-first pipeline stage with tokens scheduled, intermediate
        tensors are first received from the previous stage. If the runner
        returns a (possibly async) ``ModelRunnerOutput``, it is returned
        directly. Otherwise the returned ``IntermediateTensors`` are sent to
        the next stage.

        Returns:
            The runner output when one is produced. Otherwise ``None``, or an
            output that only carries the KV connector output when that is
            non-empty.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        # Tells PP send/recv whether to all-gather "residual" over the TP
        # group; not when it stays scattered for sequence parallelism.
        all_gather_tensors = {
            "residual": not is_residual_scattered_for_sp(
                self.vllm_config, num_input_tokens
            )
        }

        intermediate_tensors = None
        if num_scheduled_tokens > 0 and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            intermediate_tensors = IntermediateTensors(received)

        output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        # Intermediate stage: hand the tensors on to the next PP stage.
        assert isinstance(output, IntermediateTensors)
        backend = self.vllm_config.parallel_config.distributed_executor_backend
        assert backend != "external_launcher" and not get_pp_group().is_last_rank
        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output.
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        passthrough = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        passthrough.kv_connector_output = kv_connector_output
        return passthrough

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the model runner's draft token ids, or None."""
        runner = self.model_runner
        return runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the torch profiler.

        On stop, a worker whose ``local_rank`` is 0 also prints a summary
        table sorted by ``self_cuda_time_total``.

        Raises:
            RuntimeError: if no profiler was configured for this worker.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")

        if is_start:
            profiler.start()
            return

        profiler.stop()
        # Gated on `local_rank`, not the global `rank`: each group of workers
        # that shares local rank 0 prints its own summary.
        if self.local_rank == 0:
            summary = profiler.key_averages().table(sort_by="self_cuda_time_total")
            print(summary)

    def execute_dummy_batch(self) -> None:
        """Run a one-token dummy forward pass with ``uniform_decode=True``."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register ``lora_request`` with the model runner; return its result."""
        runner = self.model_runner
        return runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove adapter ``lora_id`` via the model runner; return its result."""
        runner = self.model_runner
        return runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids known to the model runner."""
        runner = self.model_runner
        return runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin adapter ``lora_id`` via the model runner; return its result."""
        runner = self.model_runner
        return runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """No-op: a worker that is still running is considered healthy."""
        return None

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts before the EP group shrinks to ``new_ep_size``.

        In the rank mapping passed to ``eplb_state.rearrange``, old EP ranks
        ``>= new_ep_size`` map to -1 and all other ranks keep their index.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        rank_mapping = {
            rank: (rank if rank < new_ep_size else -1) for rank in range(old_ep_size)
        }
        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        # Wait for queued CUDA work before reporting completion.
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_load: torch.Tensor | None,
    ) -> None:
        """Reshard experts after the EP group has grown.

        Existing EP ranks keep their index (identity mapping over
        ``range(old_ep_size)``), and ``global_expert_load`` is forwarded to
        ``eplb_state.rearrange``.

        Args:
            old_ep_size: EP group size before scaling up.
            new_ep_size: EP group size after scaling up (not used here).
            global_expert_load: Expert load passed to ``rearrange``, or None.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping,
        )
        # Match `_eplb_before_scale_down`: wait for queued CUDA work so the
        # "completed" log is emitted only after the reshuffle has finished.
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Update parallel config with provided reconfig_request.

        The DP size, master IP and master port are always overwritten. The
        global and local DP ranks are overwritten only when the request does
        not carry ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.
        """
        pcfg = self.vllm_config.parallel_config
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK

        pcfg.data_parallel_size = reconfig_request.new_data_parallel_size

        new_dp_rank = reconfig_request.new_data_parallel_rank
        if new_dp_rank != keep_current:
            pcfg.data_parallel_rank = new_dp_rank

        new_dp_rank_local = reconfig_request.new_data_parallel_rank_local
        if new_dp_rank_local != keep_current:
            pcfg.data_parallel_rank_local = new_dp_rank_local

        pcfg.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip
        pcfg.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """
        Reconfigure MoE modules after the EP group resizes from
        ``old_ep_size`` to ``new_ep_size``.

        For every ``FusedMoE``/``SharedFusedMoE`` module, the global expert
        count and the MoE parallel config are rebuilt for the new EP size.
        The EPLB redundant-expert count is updated, communication buffers are
        re-prepared, and the model's physical-expert metadata is refreshed.

        Returns:
            When ``new_ep_size < old_ep_size``, None. Otherwise (scale up or
            same size), the global expert load computed by
            ``eplb_state.rearrange(..., execute_shuffle=False)``.

        Raises:
            RuntimeError: if the model contains no FusedMoE/SharedFusedMoE
                modules.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig

        parallel_config = self.vllm_config.parallel_config
        # Selected by exact class name; subclasses with other names are not
        # matched.
        moe_class_names = ("FusedMoE", "SharedFusedMoE")
        moe_modules = [
            module
            for module in self.model_runner.model.modules()
            if module.__class__.__name__ in moe_class_names
        ]
        if not moe_modules:
            raise RuntimeError(
                "Cannot reconfigure MoE: the model has no FusedMoE or "
                "SharedFusedMoE modules."
            )

        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"

        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config

        if new_ep_size < old_ep_size:
            # Scale down: derive the physical/redundant expert counts from the
            # current EPLB state.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]
            )
            global_expert_load = None
        else:
            # Scale up (or same size): adopt EP rank 0's local expert count on
            # every rank, then compute the expert load without shuffling.
            local_count_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                local_count_tensor, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = local_count_tensor.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1]
            )

        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_load

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild this worker's distributed state for a resize.

        The steps run in this order:

        1. Record the current expert-parallel (EP) size and this rank's EP
           rank.
        2. Derive the new EP size as new DP size * TP size * PP size.
        3. When shrinking, run ``_eplb_before_scale_down`` while the old
           groups still exist.
        4. Destroy the current distributed environment.
        5. If the request tells this rank to shut down, return immediately.
           Such a rank must lie outside the new EP range.
        6. Otherwise:
           - update ``parallel_config`` from the request;
           - re-initialize the distributed environment with
             ``self.vllm_config`` set as the current config;
           - reconfigure the MoE layers;
           - when growing, run ``_eplb_after_scale_up`` with the collected
             global expert load.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        current_ep = get_ep_group()
        old_ep_size = current_ep.world_size
        old_ep_rank = current_ep.rank

        target_dp_size = reconfig_request.new_data_parallel_size
        new_ep_size = (
            target_dp_size * get_tp_group().world_size * get_pp_group().world_size
        )

        shrinking = new_ep_size < old_ep_size
        if shrinking:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        this_rank_removed = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )
        if this_rank_removed:
            # A removed rank has no slot in the smaller EP group.
            assert old_ep_rank >= new_ep_size
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        growing = new_ep_size > old_ep_size
        if growing:
            assert expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, expert_load)
785

786
787
788
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Write the runner's model to ``path`` via ``ShardedStateLoader``.

        ``pattern`` and ``max_size`` are forwarded unchanged as keyword
        arguments to ``ShardedStateLoader.save_model``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        save = ShardedStateLoader.save_model
        save(self.model_runner.model, path, pattern=pattern, max_size=max_size)

801
802
803
804
805
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Forward ``tensorizer_config`` to the model runner's tensorizer save."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)
808

809
    def shutdown(self) -> None:
        """Ask the model runner to shut down its KV-transfer state.

        Does nothing if ``model_runner`` is absent or falsy.
        """
        runner = getattr(self, "model_runner", None)
        if not runner:
            return
        runner.ensure_kv_transfer_shutdown()
812

813
814

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment for a worker.

    The steps run in this order:

    1. ``init_batch_invariance()``.
    2. Enable custom all-reduce unless
       ``parallel_config.disable_custom_all_reduce`` is set.
    3. ``init_distributed_environment``, with the parallel world size and
       ``rank``, ``distributed_init_method``, ``local_rank`` and ``backend``
       forwarded unchanged.
    4. ``ensure_model_parallel_initialized``, with the tensor-,
       pipeline- and decode-context-parallel sizes.
    """
    par_cfg = vllm_config.parallel_config
    # Function-local import kept from the original (reason not visible here).
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()

    use_custom_all_reduce = not par_cfg.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    init_distributed_environment(
        par_cfg.world_size,
        rank,
        distributed_init_method,
        local_rank,
        backend,
    )

    ensure_model_parallel_initialized(
        par_cfg.tensor_parallel_size,
        par_cfg.pipeline_parallel_size,
        par_cfg.decode_context_parallel_size,
    )