gpu_worker.py 33.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
import copy
6
7
import gc
import os
8
from contextlib import AbstractContextManager, nullcontext
9
from types import NoneType
10
from typing import TYPE_CHECKING, Any
11
12
13

import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import VllmConfig
18
19
20
21
22
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
23
24
25
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
28
29
30
31
from vllm.distributed.parallel_state import (
    get_pp_group,
    get_tp_group,
)
32
from vllm.logger import init_logger
33
from vllm.lora.request import LoRARequest
34
from vllm.model_executor import set_random_seed
35
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
36
from vllm.platforms import current_platform
37
from vllm.sequence import IntermediateTensors
38
from vllm.tasks import SupportedTask
39
40
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
41
from vllm.v1.core.sched.output import GrammarOutput
42
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
43
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
44
45
46
47
48
49
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
50
from vllm.v1.utils import report_usage_stats
51
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
52
from vllm.v1.worker.utils import is_residual_scattered_for_sp
53
from vllm.v1.worker.worker_base import WorkerBase
54
55
56
57

logger = init_logger(__name__)

if TYPE_CHECKING:
    # Imported only for type annotations (e.g. the "SchedulerOutput" string
    # annotation on execute_model); never imported at runtime.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
63
64
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create a GPU worker.

        All arguments are forwarded unchanged to ``WorkerBase``. On top of
        that, this initializes cached HF modules when ``trust_remote_code``
        is set. It prepares the storage used by level-2 sleep. It also builds
        a torch profiler when ``VLLM_TORCH_PROFILER_DIR`` is set
        (``self.profiler`` is ``None`` otherwise).
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # Imported lazily: only needed when remote code is trusted.
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # CPU copies of model buffers taken by a level-2 sleep, restored by
        # wake_up().
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler, enabled and configured purely through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        self.profiler = None
        trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        if trace_dir:
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                trace_dir,
            )

            record_shapes = envs.VLLM_TORCH_PROFILER_RECORD_SHAPES
            profile_memory = envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY
            with_stack = envs.VLLM_TORCH_PROFILER_WITH_STACK
            with_flops = envs.VLLM_TORCH_PROFILER_WITH_FLOPS
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                record_shapes,
                profile_memory,
                with_stack,
                with_flops,
            )

            trace_handler = torch.profiler.tensorboard_trace_handler(
                trace_dir, worker_name=worker_name, use_gzip=True
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=record_shapes,
                profile_memory=profile_memory,
                with_stack=with_stack,
                with_flops=with_flops,
                on_trace_ready=trace_handler,
            )

    def sleep(self, level: int = 1) -> None:
        """Put the worker to sleep through the CuMem allocator.

        Args:
            level: With 1, ``offload_tags=("weights",)`` is passed to the
                allocator. Presumably weights are kept in CPU memory and other
                pools are released; confirm against ``CuMemAllocator.sleep``.
                With any other level, no tags are offloaded. Level 2 also first
                saves a CPU copy of every model buffer, for ``wake_up`` to
                restore.

        Raises:
            AssertionError: if free GPU memory went down across the call.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        if level == 2:
            # Nothing is offloaded at this level, so keep the buffers on CPU.
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone()
                for buf_name, buf in self.model_runner.model.named_buffers()
            }

        offload_tags = ("weights",) if level == 1 else ()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        used_bytes = total - free_after
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the worker up after ``sleep``.

        Args:
            tags: Allocator tags to wake (e.g. ``["weights"]`` or
                ``["kv_cache"]``). The value is passed straight to
                ``CuMemAllocator.wake_up``. ``None`` presumably wakes every
                pool; confirm against the allocator.

        Model buffers are allocated while loading the model inside the
        "weights" memory pool. Buffers saved by a level-2 sleep are therefore
        only written back once that pool has been woken, i.e. when ``tags`` is
        ``None`` or contains ``"weights"``. Waking only other pools keeps the
        saved copies for a later call instead of writing into (and then
        discarding) buffer memory that is not awake yet.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        weights_awake = tags is None or "weights" in tags
        if self._sleep_saved_buffers and weights_awake:
            saved = self._sleep_saved_buffers
            for name, buffer in self.model_runner.model.named_buffers():
                if name in saved:
                    buffer.data.copy_(saved[name].data)
            self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return the allocation context for memory tagged ``tag``.

        With sleep mode enabled, this is ``CuMemAllocator.use_memory_pool``
        for ``tag``. Otherwise it is a no-op ``nullcontext``.

        Raises:
            AssertionError: for ``tag == "weights"`` when the allocator
                already has usage (only one sleep-mode instance per process).
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU/CPU KV-cache block counts on the cache config."""
        cache_config = self.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind the worker to its CUDA device and build the model runner.

        Steps, in order:

        1. Select ``cuda:<local_rank>``. With DP (size > 1, local size > 0)
           and neither the Ray nor the external-launcher backend,
           ``local_rank`` is first offset by ``dp_rank * (PP * TP)``.
        2. Initialize the distributed environment and seed the RNGs.
        3. Snapshot GPU memory and check that the requested
           ``total * gpu_memory_utilization`` bytes are free.
        4. Construct the ``GPUModelRunner``. Rank 0 also reports usage stats.

        Raises:
            RuntimeError: if the configured device type is not CUDA.
            ValueError: if free memory at startup is below the requested
                amount.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        parallel_config = self.parallel_config
        offset_by_dp_rank = (
            parallel_config.data_parallel_size > 1
            and parallel_config.data_parallel_size_local > 0
            and parallel_config.distributed_executor_backend
            not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
        )
        if offset_by_dp_rank:
            # Prefer the node-local DP rank; fall back to the global one.
            dp_rank = parallel_config.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = parallel_config.data_parallel_rank
            ranks_per_replica = (
                parallel_config.pipeline_parallel_size
                * parallel_config.tensor_parallel_size
            )
            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_rank * ranks_per_replica
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # The distributed environment comes up BEFORE the memory snapshot so
        # that NCCL buffers are already allocated when free memory is measured.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )
        set_random_seed(self.model_config.seed)

        gc.collect()
        torch.cuda.empty_cache()

        self.init_snapshot = snapshot = MemorySnapshot()
        gpu_util = self.cache_config.gpu_memory_utilization
        self.requested_memory = snapshot.total_memory * gpu_util

        if snapshot.free_memory < self.requested_memory:

            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            raise ValueError(
                f"Free memory on device "
                f"({to_gib(snapshot.free_memory)}/"
                f"{to_gib(snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({gpu_util}, "
                f"{to_gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model inside the "weights" allocation context.

        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH == "1"`` is forwarded to the model
        runner as ``eep_scale_up=True``.
        """
        is_eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        weights_context = self._maybe_get_memory_pool_context(tag="weights")
        with weights_context:
            self.model_runner.load_model(eep_scale_up=is_eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Ask the model runner to reload the model weights."""
        runner = self.model_runner
        runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return the number of bytes the KV cache may use.

        If ``cache_config.kv_cache_memory_bytes`` is set (truthy), a profile
        run is still executed, because it compiles the model for
        ``max_num_batched_tokens``. The configured value is then returned
        unchanged, ignoring ``gpu_memory_utilization``.

        Otherwise a dummy forward pass is profiled against the snapshot taken
        in ``init_device``. The KV budget is ``self.requested_memory`` minus
        the profile's ``non_kv_cache_memory``. ``non_torch_memory``,
        ``peak_activation_memory`` and ``available_kv_cache_memory_bytes``
        are stored on ``self`` for later reporting.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """

        def to_gib(num_bytes):
            return num_bytes / GiB_bytes

        manual_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if manual_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            message = (
                f"Initial free memory {to_gib(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {to_gib(manual_kv_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(message)
            return manual_kv_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile a forward pass with dummy inputs.
        weights_bytes = int(self.model_runner.model_memory_usage)
        with memory_profiling(
            self.init_snapshot, weights_memory=weights_bytes
        ) as result:
            self.model_runner.profile_run()

        self.non_torch_memory = result.non_torch_increase
        self.peak_activation_memory = result.torch_peak_increase

        free_after_profile = result.after_profile.free_memory
        initial_free = self.init_snapshot.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert initial_free > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {to_gib(initial_free)} GiB, "
            f"current free memory {to_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - result.non_kv_cache_memory
        )

        unrequested_memory = initial_free - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            to_gib(initial_free),
            self.cache_config.gpu_memory_utilization,
            to_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            to_gib(free_after_profile),
            to_gib(free_after_profile - unrequested_memory),
        )
        logger.debug(result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            to_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return ``{tp_rank: metadata}`` from the KV connector, if any.

        Returns ``None`` in two cases: no KV transfer group exists, or the
        connector reports no handshake metadata to exchange across workers.
        """
        if has_kv_transfer_group():
            metadata = get_kv_transfer_group().get_handshake_metadata()
            if metadata is not None:
                return {get_tp_group().rank_in_group: metadata}
        return None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's KV-cache spec."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The KV connector is initialized first, from a shallow copy of the
        vLLM config that carries a copy of ``kv_cache_config``. The cache is
        then allocated inside the "kv_cache" allocation context returned by
        ``_maybe_get_memory_pool_context``: a CuMem pool with sleep mode on,
        a no-op context otherwise.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        connector_vllm_config = copy.copy(self.vllm_config)
        connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
        ensure_kv_transfer_initialized(connector_vllm_config)

        # Shared helper instead of a duplicated sleep-mode branch. For tags
        # other than "weights" it performs no extra checks.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile and warm up the model, then capture CUDA graphs.

        Steps, in order:

        1. Run dummy passes (largest first) for ``compile_sizes``. Sizes
           already in ``cudagraph_capture_sizes`` are skipped unless
           ``enforce_eager``.
        2. Warm up kernels, then capture CUDA graphs unless ``enforce_eager``.
        3. If the KV cache size came from memory profiling, debug-log
           suggested ``--kv-cache-memory`` values.
        4. On the last PP rank, run a dummy pass at the max request count,
           followed by a dummy pooler or sampler pass.
        5. Re-seed the RNG.
        """
        compilation_config = self.vllm_config.compilation_config
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            captured_sizes = compilation_config.cudagraph_capture_sizes
            warmup_sizes = [s for s in warmup_sizes if s not in captured_sizes]

        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = (
            0 if self.model_config.enforce_eager else self.model_runner.capture_model()
        )

        profiled_kv_cache = self.cache_config.kv_cache_memory_bytes is None
        if profiled_kv_cache and hasattr(self, "peak_activation_memory"):
            self._log_kv_cache_memory_suggestion(cuda_graph_memory_bytes)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def _log_kv_cache_memory_suggestion(self, cuda_graph_memory_bytes: int) -> None:
        """Debug-log ``--kv-cache-memory`` values derived from profiling.

        ``memory_profiling`` does not account for CUDA graph memory, so the
        profiled KV cache size may not use all GPU memory. This logs two
        byte counts: one that fits the requested memory, and one that fits
        all free memory at startup. Both subtract weights, peak activation,
        non-torch and CUDA graph memory, plus a 150 MiB safety buffer.
        """

        def to_gib(num_bytes):
            return round(num_bytes / GiB_bytes, 2)

        # empirically observed that the memory profiling may
        # slightly underestimate the memory consumption.
        # So leave a small buffer (=150MiB) to avoid OOM.
        redundancy_buffer_memory = 150 * (1 << 20)
        non_kv_cache_memory = (
            self.model_runner.model_memory_usage
            + self.peak_activation_memory
            + self.non_torch_memory
            + cuda_graph_memory_bytes
        )
        fit_gpu_bytes = (
            self.init_snapshot.free_memory
            - non_kv_cache_memory
            - redundancy_buffer_memory
        )
        fit_requested_bytes = (
            int(self.requested_memory)
            - non_kv_cache_memory
            - redundancy_buffer_memory
        )

        logger.debug(
            f"Free memory on device "
            f"({to_gib(self.init_snapshot.free_memory)}/"
            f"{to_gib(self.init_snapshot.total_memory)} GiB) on startup. "
            f"Desired GPU memory utilization is "
            f"({self.cache_config.gpu_memory_utilization}, "
            f"{to_gib(self.requested_memory)} GiB). "
            f"Actual usage is {to_gib(self.model_runner.model_memory_usage)} "
            f"GiB for weight, {to_gib(self.peak_activation_memory)} GiB "
            f"for peak activation, {to_gib(self.non_torch_memory)} GiB "
            f"for non-torch memory, and {to_gib(cuda_graph_memory_bytes)} "
            f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
            f"config with `--kv-cache-memory="
            f"{fit_requested_bytes}` "
            f"({to_gib(fit_requested_bytes)} GiB) to fit "
            f"into requested memory, or `--kv-cache-memory="
            f"{fit_gpu_bytes}` "
            f"({to_gib(fit_gpu_bytes)} GiB) to fully "
            f"utilize gpu memory. Current kv cache memory in use is "
            f"{to_gib(self.available_kv_cache_memory_bytes)} GiB."
        )

    def reset_mm_cache(self) -> None:
        """Reset the model runner's multimodal cache."""
        runner = self.model_runner
        runner.reset_mm_cache()

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        runner = self.model_runner
        return runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the loaded model."""
        runner = self.model_runner
        return runner.get_supported_tasks()

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Delegate token sampling to the model runner.

        Args:
            grammar_output: Forwarded unchanged to
                ``model_runner.sample_tokens``.
        """
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)

    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one model step, handling pipeline-parallel hand-off.

        On a non-first PP rank with tokens scheduled, intermediate tensors
        are first received from the previous stage. A ``ModelRunnerOutput``
        or ``None`` from the runner is returned as-is. ``IntermediateTensors``
        are sent to the next stage instead. Then ``None`` is returned, or the
        KV connector output wrapped in ``EMPTY_MODEL_RUNNER_OUTPUT`` (or a
        copy of it).
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        # Whether the residual must be all-gathered over TP for this size.
        all_gather_tensors = {
            "residual": not is_residual_scattered_for_sp(
                self.vllm_config, num_input_tokens
            )
        }

        intermediate_tensors = None
        if num_scheduled_tokens > 0 and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            intermediate_tensors = IntermediateTensors(received)

        runner_output = self.model_runner.execute_model(
            scheduler_output, intermediate_tensors
        )
        if isinstance(runner_output, (ModelRunnerOutput, NoneType)):
            return runner_output

        # Only a non-last PP stage gets here: forward activations onward.
        assert isinstance(runner_output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )
        get_pp_group().send_tensor_dict(
            runner_output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        kv_connector_output = runner_output.kv_connector_output
        if not kv_connector_output:
            return None
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        passthrough = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        passthrough.kv_connector_output = kv_connector_output
        return passthrough

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return draft token ids from the model runner (may be ``None``)."""
        runner = self.model_runner
        return runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        Args:
            is_start: ``True`` starts profiling. ``False`` stops it; the
                worker whose ``local_rank`` is 0 also prints a summary table
                sorted by ``self_cuda_time_total``.

        Raises:
            RuntimeError: if no profiler was configured
                (``VLLM_TORCH_PROFILER_DIR`` unset).
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")

        if is_start:
            profiler.start()
            return

        profiler.stop()
        # Print only on local rank 0 (the node-local ``local_rank``, not the
        # global ``rank``).
        if self.local_rank == 0:
            summary = profiler.key_averages().table(sort_by="self_cuda_time_total")
            print(summary)

    def execute_dummy_batch(self) -> None:
        """Run a one-token dummy forward pass with ``uniform_decode=True``."""
        runner = self.model_runner
        runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model runner; return its result."""
        runner = self.model_runner
        return runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model runner to remove LoRA ``lora_id``; return its result."""
        runner = self.model_runner
        return runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        runner = self.model_runner
        return runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model runner to pin LoRA ``lora_id``; return its result."""
        runner = self.model_runner
        return runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """No-op: a worker that is running is considered healthy."""
        return None

    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts before the EP group shrinks to ``new_ep_size``.

        Ranks ``< new_ep_size`` keep their index in the mapping passed to
        ``eplb_state.rearrange``; ranks being removed map to ``-1``. A CUDA
        synchronize follows the reshuffle.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        rank_mapping = {
            ep_rank: (ep_rank if ep_rank < new_ep_size else -1)
            for ep_rank in range(old_ep_size)
        }
        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_load: torch.Tensor | None,
    ) -> None:
        """Reshard experts after the EP group grew from ``old_ep_size``.

        Existing ranks keep their index in the mapping passed to
        ``eplb_state.rearrange``. ``global_expert_load`` is forwarded as-is.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        identity_mapping = {ep_rank: ep_rank for ep_rank in range(old_ep_size)}
        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=identity_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Copy DP size, ranks and master address from ``reconfig_request``.

        A rank field equal to ``ReconfigureRankType.KEEP_CURRENT_RANK`` leaves
        the current value in ``parallel_config`` untouched.
        """
        parallel_config = self.vllm_config.parallel_config
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK

        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size

        new_rank = reconfig_request.new_data_parallel_rank
        if new_rank != keep_current:
            parallel_config.data_parallel_rank = new_rank

        new_local_rank = reconfig_request.new_data_parallel_rank_local
        if new_local_rank != keep_current:
            parallel_config.data_parallel_rank_local = new_local_rank

        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> torch.Tensor | None:
        """Resize every MoE layer of the model for a new expert-parallel size.

        Each ``FusedMoE`` / ``SharedFusedMoE`` module gets its global expert
        count (``num_local_experts * new_ep_size``) and a freshly built
        ``FusedMoEParallelConfig``. The EPLB redundant-expert count is then
        recomputed, communication buffers are re-prepared, and the model's
        physical-expert metadata is refreshed.

        Args:
            old_ep_size: Expert-parallel world size before reconfiguration.
            new_ep_size: Expert-parallel world size after reconfiguration.

        Returns:
            ``None`` when scaling down (``new_ep_size < old_ep_size``).
            Otherwise (scale-up, or an unchanged size) the global expert load
            returned by ``eplb_state.rearrange(model, execute_shuffle=False)``.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig

        parallel_config = self.vllm_config.parallel_config
        model = self.model_runner.model

        # Match by class name so both MoE layer variants are picked up.
        moe_class_names = ("FusedMoE", "SharedFusedMoE")
        moe_modules = [
            module
            for module in model.modules()
            if module.__class__.__name__ in moe_class_names
        ]
        assert moe_modules, "Expected at least one MoE module in the model"

        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"

        # The TP/DP group sizes do not change while updating the modules.
        tp_world_size = get_tp_group().world_size
        dp_world_size = get_dp_group().world_size
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=tp_world_size,
                dp_size_=dp_world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config

        if new_ep_size < old_ep_size:
            # Scale-down: derive the counts from the existing EPLB state.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            eplb_state = self.model_runner.eplb_state
            new_physical_experts = eplb_state.physical_to_logical_map.shape[1]
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - eplb_state.logical_replica_count.shape[1]
            )
            global_expert_load = None
        else:
            # Every EP rank adopts EP-group rank 0's local expert count.
            count_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                count_tensor, group=get_ep_group().cpu_group, group_src=0
            )
            num_local_physical_experts = count_tensor.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                model, execute_shuffle=False
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1]
            )

        prepare_communication_buffer_for_model(model)
        model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_load

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this worker's distributed state for a new data-parallel size.

        Steps:
        1. Compute the new EP size as the new DP size times the current TP
           and PP world sizes.
        2. When shrinking, run ``_eplb_before_scale_down`` while the old
           groups still exist.
        3. Tear down the current distributed environment.
        4. If this rank is being shut down, stop there.
        5. Otherwise apply the request to the parallel config and re-initialise
           the distributed environment.
        6. Resize the MoE layers.
        7. When growing, run ``_eplb_after_scale_up`` with the collected
           global expert load.

        Args:
            reconfig_request: Describes the new DP size, this rank's new DP
                rank(s) and the new DP master address/port.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Read everything needed from the old groups before they are destroyed.
        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )

        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            # A rank being removed must lie outside the new EP range; it does
            # not rebuild any process groups.
            assert old_ep_rank >= new_ep_size
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load)
792

793
794
795
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the model's weights as sharded state via ``ShardedStateLoader``.

        Args:
            path: Destination directory/path passed to the loader.
            pattern: Optional file-name pattern forwarded to the loader.
            max_size: Optional maximum shard size forwarded to the loader.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

808
809
810
811
812
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the model with Tensorizer by delegating to the model runner.

        Args:
            tensorizer_config: Tensorizer settings, forwarded unchanged as the
                ``tensorizer_config`` keyword argument.
        """
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )
815

816
    def shutdown(self) -> None:
        """Shut down the model runner's KV-transfer machinery, if a runner exists.

        Safe to call on a partially constructed worker: a missing or ``None``
        ``model_runner`` attribute is ignored.
        """
        if (runner := getattr(self, "model_runner", None)) is not None:
            runner.ensure_kv_transfer_shutdown()
819

820
821

def init_worker_distributed_environment(
822
    vllm_config: VllmConfig,
823
    rank: int,
824
    distributed_init_method: str | None = None,
825
    local_rank: int = -1,
826
    backend: str = "nccl",
827
828
) -> None:
    """Initialize the distributed environment."""
829
    parallel_config = vllm_config.parallel_config
830
831
832
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
833
834
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

835
836
837
    init_distributed_environment(
        parallel_config.world_size, rank, distributed_init_method, local_rank, backend
    )
838

839
840
841
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
842
843
        parallel_config.decode_context_parallel_size,
    )