gpu_worker.py 33.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from types import NoneType
10
from typing import TYPE_CHECKING, Any
11
12
13

import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import VllmConfig
18
19
20
21
22
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
23
24
25
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
28
29
30
31
from vllm.distributed.parallel_state import (
    get_pp_group,
    get_tp_group,
)
32
from vllm.logger import init_logger
33
from vllm.lora.request import LoRARequest
34
from vllm.model_executor import set_random_seed
35
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
36
from vllm.platforms import current_platform
37
from vllm.sequence import IntermediateTensors
38
from vllm.tasks import SupportedTask
39
40
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
41
from vllm.v1.core.sched.output import GrammarOutput
42
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
43
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
44
45
46
47
48
49
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
50
from vllm.v1.utils import report_usage_stats
51
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
52
from vllm.v1.worker.utils import is_residual_scattered_for_sp
53
from vllm.v1.worker.worker_base import WorkerBase
54
55
56
57

logger = init_logger(__name__)

# Names needed only for type annotations; importing them under
# TYPE_CHECKING keeps them out of the runtime import graph.
if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
63
64
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create a GPU worker and, if requested via env vars, a torch profiler.

        Args:
            vllm_config: Engine configuration, forwarded to ``WorkerBase``.
            local_rank: Node-local rank of this worker. ``init_device`` may
                shift it when data parallelism is used.
            rank: Rank of this worker in the distributed group (presumably
                the global rank — confirm in ``WorkerBase``).
            distributed_init_method: Rendezvous method forwarded to the
                distributed-environment setup in ``init_device``.
            is_driver_worker: Whether this worker acts as the driver.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # Imported lazily: only needed when remote code is trusted.
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Model buffers copied to CPU by a level-2 `sleep`; `wake_up` copies
        # them back into the model and then clears this dict.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                # Each worker writes its own gzipped trace, named after the
                # instance id and rank so traces from different ranks do not
                # collide.
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
                ),
            )
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Put the worker to sleep, releasing GPU memory via CuMemAllocator.

        Args:
            level: ``1`` passes ``offload_tags=("weights",)`` to
                ``CuMemAllocator.sleep``; any other value passes no offload
                tags. For ``level == 2`` every model buffer is first copied to
                CPU so that ``wake_up`` can restore it.

        Raises:
            AssertionError: If free GPU memory decreased across the call.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Level 2 offloads no tags (presumably the GPU contents are not kept),
        # so snapshot the model buffers to CPU now; `wake_up` restores them.
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone() for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())

        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the CuMem allocator and restore buffers saved by ``sleep``.

        Args:
            tags: Forwarded unchanged to ``CuMemAllocator.wake_up``
                (presumably ``None`` wakes every tagged pool — confirm in
                CuMemAllocator).
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers saved by a level-2 sleep; the dict is empty
        # after any other sleep level, so this is skipped.
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context manager that selects a tagged CuMem memory pool.

        With ``model_config.enable_sleep_mode`` set, returns
        ``CuMemAllocator.use_memory_pool(tag=tag)``; otherwise returns a no-op
        ``nullcontext()``.

        Args:
            tag: Pool tag, e.g. ``"weights"`` or ``"kv_cache"``.

        Raises:
            AssertionError: If ``tag == "weights"`` and the allocator already
                reports non-zero usage.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            # Weights are loaded first; existing usage means another instance
            # in this process is already using the allocator.
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the KV-cache block counts on ``self.cache_config``.

        Only ``num_gpu_blocks`` and ``num_cpu_blocks`` are stored; no memory
        is allocated here.
        """
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind this worker to its CUDA device and prepare the runtime.

        In order, this method:
        - selects the device, shifting ``local_rank`` for DP launches;
        - initializes the distributed environment;
        - seeds the RNGs;
        - takes the initial memory snapshot;
        - checks that the requested ``gpu_memory_utilization`` fits into free
          memory;
        - builds the ``GPUModelRunner``.

        Raises:
            RuntimeError: If the configured device type is not ``"cuda"``.
            AssertionError: If the DP-adjusted local rank has no CUDA device.
            ValueError: If free device memory is below the requested amount.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        if (
            self.parallel_config.data_parallel_size > 1
            and self.parallel_config.data_parallel_size_local > 0
            and self.parallel_config.distributed_executor_backend
            not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
        ):
            # Use local DP rank if available, otherwise use global DP rank.
            dp_local_rank = self.parallel_config.data_parallel_rank_local
            if dp_local_rank is None:
                dp_local_rank = self.parallel_config.data_parallel_rank

            tp_pp_world_size = (
                self.parallel_config.pipeline_parallel_size
                * self.parallel_config.tensor_parallel_size
            )

            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_local_rank * tp_pp_world_size
            num_visible_devices = torch.cuda.device_count()
            assert self.local_rank < num_visible_devices, (
                f"DP adjusted local rank {self.local_rank} is out of bounds "
                f"({num_visible_devices} CUDA devices visible)."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # snapshot, so that NCCL buffers are already allocated when we
        # measure available memory.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Now take memory snapshot after NCCL is initialized. Collect garbage
        # and release cached blocks first so they don't count as "used".
        gc.collect()
        torch.cuda.empty_cache()

        # take current memory snapshot
        self.init_snapshot = MemorySnapshot()
        # Bytes this worker may use: a fraction of *total* device memory.
        self.requested_memory = (
            self.init_snapshot.total_memory
            * self.cache_config.gpu_memory_utilization
        )
        if self.init_snapshot.free_memory < self.requested_memory:
            GiB = lambda b: round(b / GiB_bytes, 2)
            raise ValueError(
                f"Free memory on device "
                f"({GiB(self.init_snapshot.free_memory)}/"
                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load model weights through the model runner.

        Loading happens inside the ``"weights"`` memory-pool context (a no-op
        unless sleep mode is enabled). When the environment variable
        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` equals ``"1"``,
        ``eep_scale_up=True`` is forwarded to the runner.
        """
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner's ``update_config``."""
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Ask the model runner to reload the model weights."""
        self.model_runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        If ``cache_config.kv_cache_memory_bytes`` is set (truthy), a profile
        run is still executed, because it compiles the model for
        ``max_num_batched_tokens``. Memory profiling is skipped in that case
        and the configured value is returned as-is, ignoring
        ``gpu_memory_utilization``.

        Otherwise the result is ``self.requested_memory`` minus the
        non-KV-cache memory measured during the profile run. This path also
        sets ``self.non_torch_memory``, ``self.peak_activation_memory`` and
        ``self.available_kv_cache_memory_bytes``.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            Number of bytes available for the KV cache.

        Raises:
            AssertionError: If free memory after profiling is not lower than
                the initial free memory.
        """
        GiB = lambda b: b / GiB_bytes
        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return kv_cache_memory_bytes

        # Drop cached blocks and reset the peak counter so that the peak seen
        # during the profile run below is measured from this point.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        # Kept on the instance; `compile_or_warm_up_model` reads them to
        # suggest an explicit KV-cache size.
        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Free memory that lies outside the gpu_memory_utilization budget;
        # only used to report "free within requested" in the debug log.
        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            GiB(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            GiB(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            GiB(free_gpu_memory),
            GiB(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            GiB(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Get KV connector metadata from this worker if available.

        Returns:
            ``{tp_rank: metadata}`` keyed by this worker's rank within its TP
            group. Returns ``None`` when no KV-transfer group exists, or when
            the connector reports no handshake metadata.
        """
        if not has_kv_transfer_group():
            return None

        connector = get_kv_transfer_group()
        # Return None for connectors that don't need to exchange handshake
        # metadata across workers.
        metadata = connector.get_handshake_metadata()
        if metadata is None:
            return None

        return {get_tp_group().rank_in_group: metadata}

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's KV-cache spec mapping (keys are strings)."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The KV-transfer connector is initialized first. The cache is then
        allocated inside the ``"kv_cache"`` memory-pool context, which is a
        no-op unless sleep mode is enabled.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Reuse the shared pool-selection helper; its single-instance check
        # only applies to the "weights" tag, so behavior is unchanged here.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile and warm up the model, capture CUDA graphs, warm the sampler.

        Steps:
        1. Run a dummy forward pass for each ``compile_sizes`` entry (largest
           first). Sizes that are also CUDA-graph capture sizes are skipped,
           unless ``enforce_eager`` is set. Then call
           ``maybe_remove_all_loras``.
        2. Warm up kernels and, unless ``enforce_eager``, capture CUDA graphs.
        3. If the KV-cache size came from memory profiling, log (at debug
           level) suggested explicit ``--kv-cache-memory`` values.
        4. On the last PP rank, run a dummy pooler/sampler pass at the maximum
           request count to preallocate its buffers.
        5. Re-seed the RNGs.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x
                for x in warmup_sizes
                if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        # `peak_activation_memory` only exists if memory profiling ran in
        # `determine_available_memory`.
        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            GiB = lambda b: round(b / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({GiB(self.init_snapshot.free_memory)}/"
                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). "
                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{GiB(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        """Delegate to the model runner's ``reset_mm_cache``."""
        self.model_runner.reset_mm_cache()

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the loaded model (via the runner)."""
        return self.model_runner.get_supported_tasks()

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Sample tokens via the model runner, under ``torch.inference_mode``.

        Args:
            grammar_output: Grammar output forwarded to the runner, or
                ``None``.
        """
        return self.model_runner.sample_tokens(grammar_output)

    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one scheduled step, handling pipeline-parallel tensor hand-off.

        When tokens are scheduled, non-first PP ranks first receive
        intermediate tensors from the previous stage. If the runner returns a
        ``ModelRunnerOutput`` or ``None``, that is returned directly. If it
        returns ``IntermediateTensors``, they are sent to the next PP stage and
        only the KV-connector output, if any, is passed back.

        Returns:
            One of:
            - the runner's output;
            - ``None``;
            - ``EMPTY_MODEL_RUNNER_OUTPUT``;
            - a shallow copy of ``EMPTY_MODEL_RUNNER_OUTPUT`` carrying the
              KV-connector output.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        # The residual is all-gathered across the TP group unless it stays
        # scattered (presumably under sequence parallelism) for this size.
        all_gather_tensors = {
            "residual": not is_residual_scattered_for_sp(
                self.vllm_config, num_input_tokens
            )
        }

        intermediate_tensors = None
        if forward_pass and not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors,
                )
            )

        output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, NoneType)):
            return output

        # Anything else must be intermediate tensors headed for the next stage.
        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Shallow-copy so the shared module-level EMPTY_MODEL_RUNNER_OUTPUT is
        # never mutated.
        passthrough_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        passthrough_output.kv_connector_output = kv_connector_output
        return passthrough_output

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the runner's ``take_draft_token_ids()`` result (may be ``None``)."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler created in ``__init__``.

        Args:
            is_start: ``True`` starts profiling. ``False`` stops it and, on
                the worker whose ``local_rank`` is 0, prints a summary table
                sorted by ``self_cuda_time_total``.

        Raises:
            RuntimeError: If profiling was not enabled (``self.profiler`` is
                ``None``).
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
            return

        self.profiler.stop()
        # Only print profiler results on local rank 0. Note that this checks
        # the node-local rank, not the global rank.
        if self.local_rank == 0:
            print(self.profiler.key_averages().table(sort_by="self_cuda_time_total"))

    def execute_dummy_batch(self) -> None:
        """Run a minimal dummy forward pass (size 1, ``uniform_decode=True``)."""
        self.model_runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model runner; returns its result."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter ``lora_id`` via the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the set of LoRA ids currently known to the model runner."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter ``lora_id`` via the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """Health probe; always succeeds while the worker process is running."""
        # worker will always be healthy as long as it's running.
        return

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts before the expert-parallel group shrinks.

        Old EP ranks ``< new_ep_size`` keep their rank in the mapping passed
        to ``EplbState.rearrange``. Ranks ``>= new_ep_size`` are mapped to
        ``-1``, presumably marking them as removed. A CUDA synchronize follows
        the shuffle.

        Args:
            old_ep_size: EP world size before scaling down.
            new_ep_size: EP world size after scaling down.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_load: torch.Tensor | None,
    ) -> None:
        """Reshard experts after the expert-parallel group has grown.

        Every old EP rank keeps its rank in the mapping passed to
        ``EplbState.rearrange``, together with ``global_expert_load``.

        Args:
            old_ep_size: EP world size before scaling up.
            new_ep_size: EP world size after scaling up. It is not read by
                this method and appears to be kept for signature symmetry with
                ``_eplb_before_scale_down``.
            global_expert_load: Expert load forwarded to ``rearrange``.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
        rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Update parallel config with the provided ``reconfig_request``.

        These fields of ``self.vllm_config.parallel_config`` are always
        overwritten:
        - ``data_parallel_size``
        - ``data_parallel_master_ip``
        - ``data_parallel_master_port``

        The DP rank and local DP rank are only overwritten when the request
        does not carry ``ReconfigureRankType.KEEP_CURRENT_RANK``.
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        # KEEP_CURRENT_RANK is a sentinel: leave the existing rank untouched.
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """Resize every FusedMoE / SharedFusedMoE module for a new EP size.

        For each MoE module, updates the expert counts and rebuilds the MoE
        parallel config from ``self.vllm_config.parallel_config`` and the
        current TP/DP group sizes. It then:
        - updates ``eplb_config.num_redundant_experts``;
        - re-prepares the communication buffers;
        - refreshes the model's physical-expert metadata.

        Args:
            old_ep_size: Expert-parallel world size before reconfiguration.
            new_ep_size: Expert-parallel world size after reconfiguration.

        Returns:
            The global expert load computed by ``EplbState.rearrange`` (with
            ``execute_shuffle=False``) if ``new_ep_size >= old_ep_size``,
            otherwise ``None``.

        Raises:
            AssertionError: If the model has no MoE modules, if the MoE
                modules disagree on ``num_local_experts``, or if EPLB state
                is missing.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig

        parallel_config = self.vllm_config.parallel_config
        # Matched by class name, so subclasses with other names are not picked up.
        moe_modules = [
            module
            for module in self.model_runner.model.modules()
            if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
        ]
        assert moe_modules, "No FusedMoE/SharedFusedMoE modules found in the model"
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config
        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]
            )
            global_expert_load = None
        else:
            # Take EP rank 0's local-expert count so every rank agrees.
            num_local_physical_experts = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = num_local_physical_experts.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1]
            )
        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_load

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Tear down and rebuild this worker's distributed state for a new DP size.

        Sequence:
        1. Record the current EP size and rank, then compute the new EP size as
           ``new_data_parallel_size * TP world size * PP world size``.
        2. When shrinking, run ``_eplb_before_scale_down`` while the old groups
           still exist.
        3. Destroy the current distributed environment.
        4. If this rank is marked ``SHUTDOWN_CURRENT_RANK``, return.
        5. Otherwise continue with the steps below:
           - Apply the new parallel config.
           - Re-initialize the distributed environment under
             ``self.vllm_config``.
           - Reconfigure the MoE layers.
           - When growing, run ``_eplb_after_scale_up`` with the returned
             expert load.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Snapshot the EP topology before the groups are destroyed below.
        ep_group = get_ep_group()
        old_ep_size = ep_group.world_size
        old_ep_rank = ep_group.rank

        # NOTE(review): EP size here includes the PP world size; confirm this
        # matches how the EP group is built when PP > 1.
        tp_world = get_tp_group().world_size
        pp_world = get_pp_group().world_size
        new_ep_size = reconfig_request.new_data_parallel_size * tp_world * pp_world

        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        leaving = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )
        if leaving:
            # A departing rank must sit outside the shrunken EP range.
            assert old_ep_rank >= new_ep_size
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, expert_load)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """
        Save the model runner's model to ``path`` with ``ShardedStateLoader``.

        ``pattern`` and ``max_size`` are forwarded unchanged as keyword
        arguments to ``ShardedStateLoader.save_model``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Forward ``tensorizer_config`` to the model runner's own saver."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)

    def shutdown(self) -> None:
        """
        Shut down KV transfer on the model runner.

        Does nothing when no ``model_runner`` attribute exists or it is falsy.
        """
        runner = getattr(self, "model_runner", None)
        if not runner:
            return
        runner.ensure_kv_transfer_shutdown()


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """
    Initialize the distributed environment for a worker.

    The steps run in this order:

    - Call ``init_batch_invariance()``.
    - Enable custom all-reduce unless
      ``parallel_config.disable_custom_all_reduce`` is set.
    - Set up the distributed environment for ``parallel_config.world_size``
      using the given rank, init method, local rank and backend.
    - Ensure the tensor-, pipeline- and decode-context-parallel groups exist,
      using the sizes from ``parallel_config``.
    """
    pconf = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()

    use_custom_all_reduce = not pconf.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    init_distributed_environment(
        pconf.world_size, rank, distributed_init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        pconf.tensor_parallel_size,
        pconf.pipeline_parallel_size,
        pconf.decode_context_parallel_size,
    )