"vllm/vscode:/vscode.git/clone" did not exist on "36c260dad604ccc845150753f2530b5b2ba9d7e6"
gpu_worker.py 35.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any, cast
10
11
12

import torch
import torch.distributed
13
import torch.nn as nn
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
18
19
20
21
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
22
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
23
24
25
26
27
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
28
from vllm.distributed.parallel_state import (
29
    get_pcp_group,
30
31
32
    get_pp_group,
    get_tp_group,
)
33
from vllm.logger import init_logger
34
from vllm.lora.request import LoRARequest
35
from vllm.model_executor import set_random_seed
36
from vllm.model_executor.models.interfaces import is_mixture_of_experts
37
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
38
from vllm.platforms import current_platform
39
from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
40
from vllm.sequence import IntermediateTensors
41
from vllm.tasks import SupportedTask
42
43
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
Woosuk Kwon's avatar
Woosuk Kwon committed
44
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
45
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
46
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
47
48
49
50
51
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
52
from vllm.v1.utils import report_usage_stats
53
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
54
from vllm.v1.worker.utils import is_residual_scattered_for_sp
55
from vllm.v1.worker.worker_base import WorkerBase
56
57
58
59

# Module-level logger used by every method of the worker in this file.
logger = init_logger(__name__)

if TYPE_CHECKING:
    # Imported only for static type checking; skipped at runtime.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
61
62


63
class Worker(WorkerBase):
64
65
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create the worker; the device and model runner are set up later.

        All arguments are forwarded unchanged to ``WorkerBase.__init__``.
        Afterwards this constructor:

        * calls ``init_cached_hf_modules()`` when the model config has
          ``trust_remote_code`` set,
        * creates the (initially empty) store for buffers saved by a
          level-2 sleep,
        * picks a profiler from the environment, and
        * records whether the V2 model runner was requested.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch/CUDA profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        # VLLM_TORCH_CUDA_PROFILE=1
        # The torch profiler takes precedence when both are set; with neither
        # set the worker has no profiler.
        selected_profiler: Any | None = None
        if envs.VLLM_TORCH_PROFILER_DIR:
            selected_profiler = TorchProfilerWrapper(
                worker_name=f"{vllm_config.instance_id}-rank-{self.rank}",
                local_rank=self.local_rank,
            )
        elif envs.VLLM_TORCH_CUDA_PROFILE:
            selected_profiler = CudaProfilerWrapper()
        self.profiler: Any | None = selected_profiler

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

105
    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep and log how much memory was freed.

        Args:
            level: With ``1`` the allocator is asked to offload allocations
                tagged ``"weights"``; any other value passes no offload
                tags. With ``2`` the model's named buffers are first copied
                to CPU so that ``wake_up`` can restore them.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        # Save the buffers before level 2 sleep
        if level == 2:
            cpu_copies: dict[str, torch.Tensor] = {}
            for name, buffer in self.model_runner.model.named_buffers():
                cpu_copies[name] = buffer.cpu().clone()
            self._sleep_saved_buffers = cpu_copies

        offload_tags = ("weights",) if level == 1 else tuple()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            (total - free_after) / GiB_bytes,
        )
128

129
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the CuMem allocator and restore buffers saved by a level-2 sleep.

        Args:
            tags: Allocation tags forwarded unchanged to
                ``CuMemAllocator.wake_up``; ``None`` is forwarded as-is.

        Buffers stashed by ``sleep(level=2)`` are copied back into the model
        only when this wake-up covers the weights (``tags`` is ``None`` or
        contains ``"weights"``). A wake-up limited to other tags leaves the
        saved copies in place for a later call.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers after level 2 sleep.
        # NOTE(review): model buffers are presumably allocated in the
        # "weights" pool (load_model runs under that tag), so copying into
        # them before that pool is woken would write to memory that is not
        # mapped yet. Keep the saved copies until the weights are woken.
        weights_woken = tags is None or "weights" in tags
        if self._sleep_saved_buffers and weights_woken:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

143
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
144
145
146
147
148
149
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
150
151
                    "Sleep mode can only be used for one instance per process."
                )
152
            return allocator.use_memory_pool(tag=tag)
153
        else:
154
            return nullcontext()
155

156
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
157
158
159
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

160
    def init_device(self):
        """Bind the worker to its CUDA device and build the model runner.

        In order, this method: rejects non-CUDA devices, offsets
        ``local_rank`` for locally launched data-parallel ranks, selects the
        device, initializes the distributed environment, seeds the RNG,
        takes the initial memory snapshot, checks it against the requested
        memory budget, and constructs the model runner. Rank 0 additionally
        reports usage stats.

        Raises:
            RuntimeError: If the configured device is not a CUDA
                ``torch.device``.
            ValueError: If free device memory at startup is below
                ``total_memory * gpu_memory_utilization``.
        """
        device = self.device_config.device
        is_cuda_device = isinstance(device, torch.device) and device.type == "cuda"
        if not is_cuda_device:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        needs_dp_rank_offset = (
            self.parallel_config.data_parallel_size > 1
            and self.parallel_config.data_parallel_size_local > 0
            and self.parallel_config.distributed_executor_backend
            not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
            and self.vllm_config.parallel_config.nnodes_within_dp == 1
        )
        if needs_dp_rank_offset:
            # Use local DP rank if available, otherwise use global DP rank.
            dp_local_rank = self.parallel_config.data_parallel_rank_local
            if dp_local_rank is None:
                dp_local_rank = self.parallel_config.data_parallel_rank

            tp_pp_world_size = (
                self.parallel_config.pipeline_parallel_size
                * self.parallel_config.tensor_parallel_size
            )

            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            self.local_rank += dp_local_rank * tp_pp_world_size
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            if torch.cuda.is_available():
                visible_device_count = torch.cuda.device_count()
            else:
                visible_device_count = 0
            assert self.parallel_config.local_world_size <= visible_device_count, (
                f"local_world_size ({self.parallel_config.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({visible_device_count})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)

        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking
        # memory snapshot
        # This ensures NCCL buffers are allocated before we measure
        # available memory
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Now take memory snapshot after NCCL is initialized
        gc.collect()
        torch.cuda.empty_cache()

        # take current memory snapshot
        self.init_snapshot = MemorySnapshot()
        self.requested_memory = (
            self.init_snapshot.total_memory
            * self.cache_config.gpu_memory_utilization
        )
        if self.init_snapshot.free_memory < self.requested_memory:
            # Values shown in the error are GiB rounded to two decimals.
            free_gib = round(self.init_snapshot.free_memory / GiB_bytes, 2)
            total_gib = round(self.init_snapshot.total_memory / GiB_bytes, 2)
            requested_gib = round(self.requested_memory / GiB_bytes, 2)
            raise ValueError(
                f"Free memory on device "
                f"({free_gib}/"
                f"{total_gib} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{requested_gib} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner
        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            self.model_runner = GPUModelRunner(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

257
258
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
259
    def load_model(self) -> None:
260
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
261
        with self._maybe_get_memory_pool_context(tag="weights"):
262
            self.model_runner.load_model(eep_scale_up=eep_scale_up)
263

264
265
266
    def update_config(self, overrides: dict[str, Any]) -> None:
        self.model_runner.update_config(overrides)

267
    def reload_weights(self) -> None:
268
        self.model_runner.reload_weights()
269

270
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Return the number of bytes that can be used for the KV cache.

        When ``kv_cache_memory_bytes`` is set on the cache config, a profile
        run is still executed (it compiles the model for
        max_num_batched_tokens) and the configured value is returned as-is.

        Otherwise the model runner's profile run is executed under
        ``memory_profiling`` and the result is the requested memory minus
        the profiled non-KV-cache memory. The non-torch increase, the peak
        activation memory and the resulting KV-cache budget are stored on
        the worker.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """

        def to_gib(num_bytes):
            return num_bytes / GiB_bytes

        kv_cache_memory_bytes = self.cache_config.kv_cache_memory_bytes
        if kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {to_gib(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {to_gib(kv_cache_memory_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return kv_cache_memory_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        weights_memory = int(self.model_runner.model_memory_usage)
        with memory_profiling(
            self.init_snapshot,
            weights_memory=weights_memory,
        ) as profile_result:
            self.model_runner.profile_run()

        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_after_profile = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {to_gib(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {to_gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Free memory at startup that lies outside the requested budget.
        outside_budget = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            to_gib(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            to_gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            to_gib(free_after_profile),
            to_gib(free_after_profile - outside_budget),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            to_gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
356

357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return this worker's KV connector handshake metadata, if any.

        Returns:
            ``None`` when no KV transfer group exists or the connector has
            no handshake metadata; otherwise a single-entry dict mapping
            this worker's rank in the TP group to the metadata.
        """
        if not has_kv_transfer_group():
            return None

        # Connectors that don't need to exchange handshake metadata across
        # workers report None, which is passed through.
        handshake_metadata = get_kv_transfer_group().get_handshake_metadata()
        if handshake_metadata is None:
            return None

        return {get_tp_group().rank_in_group: handshake_metadata}

372
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        runner = self.model_runner
        return runner.get_kv_cache_spec()

375
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The KV transfer connector is initialized first, then the model
        runner allocates the KV cache. With sleep mode enabled the
        allocation happens inside the allocator's ``"kv_cache"`` memory
        pool; otherwise it runs in a no-op context.

        Args:
            kv_cache_config: Configuration handed to both the KV transfer
                initialization and the model runner.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Reuse the same pool-selection helper as load_model instead of
        # duplicating the sleep-mode check here.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
393
394

    def compile_or_warm_up_model(self) -> None:
        """Warm up the model, capture CUDA graphs, and warm up sampling.

        In order: run dummy batches for compile sizes that are not CUDA
        graph capture sizes, run kernel warmup, capture the model unless
        eager mode is enforced, log a KV-cache sizing suggestion when memory
        profiling was used, warm up the sampler (or pooler) on the last
        pipeline-parallel rank, and finally reset the random seed.
        """
        runner = self.model_runner
        eager = self.model_config.enforce_eager

        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        compile_sizes = self.vllm_config.compilation_config.compile_sizes
        warmup_sizes = [] if compile_sizes is None else compile_sizes.copy()
        if not eager:
            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
            if capture_sizes is not None:
                warmup_sizes = [
                    size for size in warmup_sizes if size not in capture_sizes
                ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        runner.maybe_remove_all_loras(runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0 if eager else runner.capture_model()

        profiled = hasattr(self, "peak_activation_memory")
        if self.cache_config.kv_cache_memory_bytes is None and profiled:
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            reserved = non_kv_cache_memory
            gpu_limit = (
                self.init_snapshot.free_memory - reserved - redundancy_buffer_memory
            )
            requested_limit = (
                int(self.requested_memory) - reserved - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({to_gib(self.init_snapshot.free_memory)}/"
                f"{to_gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{to_gib(self.requested_memory)} GiB). "
                f"Actual usage is {to_gib(runner.model_memory_usage)} "
                f"GiB for weight, {to_gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {to_gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {to_gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{requested_limit}` "
                f"({to_gib(requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{gpu_limit}` "
                f"({to_gib(gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{to_gib(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if runner.is_pooling_model:
                runner._dummy_pooler_run(hidden_states)
            else:
                runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

500
501
502
    def reset_mm_cache(self) -> None:
        self.model_runner.reset_mm_cache()

503
504
505
    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

506
507
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        runner = self.model_runner
        return runner.get_supported_tasks()
508

509
510
511
512
513
514
    def annotate_profile(self, scheduler_output):
        # add trace annotation so that we can easily distinguish
        # new/cached request numbers in each iteration
        if not self.profiler:
            return nullcontext()

515
516
        self.profiler.step()

517
518
519
        num_new = len(scheduler_output.scheduled_new_reqs)
        num_cached = len(scheduler_output.scheduled_cached_reqs.req_ids)

520
        return self.profiler.annotate_context_manager(
521
522
523
            f"execute_new_{num_new}_cached_{num_cached}"
        )

524
525
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Delegate token sampling to the model runner and return its output."""
        runner = self.model_runner
        return runner.sample_tokens(grammar_output)

530
531
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one model step for this pipeline-parallel rank.

        When tokens are scheduled and this is not the first PP rank, the
        intermediate tensors are first received from the PP group. The
        model runner's result is returned directly when it is a
        ``ModelRunnerOutput`` or ``None``. Otherwise it must be
        ``IntermediateTensors``, which are sent on through the PP group, and
        ``None`` is returned.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        has_forward_pass = num_scheduled_tokens > 0
        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
        residual_scattered = is_residual_scattered_for_sp(
            self.vllm_config, num_input_tokens
        )
        all_gather_tensors = {"residual": not residual_scattered}

        intermediate_tensors = None
        if has_forward_pass and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            assert received is not None
            intermediate_tensors = IntermediateTensors(received)

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if output is None or isinstance(output, ModelRunnerOutput):
                return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        return None
572

573
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return whatever draft token ids the model runner hands over."""
        runner = self.model_runner
        return runner.take_draft_token_ids()

576
    def profile(self, is_start: bool = True):
577
        if self.profiler is None:
578
            raise RuntimeError("Profiling is not enabled.")
579
580
581
582
583
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

584
    def execute_dummy_batch(self) -> None:
Woosuk Kwon's avatar
Woosuk Kwon committed
585
586
587
588
589
590
        if self.use_v2_model_runner:
            self.model_runner.execute_model(
                SchedulerOutput.make_empty(), dummy_run=True
            )
        else:
            self.model_runner._dummy_run(1, uniform_decode=True)
591

592
593
594
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Pass the LoRA request to the model runner and return its result."""
        runner = self.model_runner
        return runner.add_lora(lora_request)

595
596
597
    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

598
    def list_loras(self) -> set[int]:
599
600
601
602
603
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

604
605
606
607
    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

608
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts before the expert-parallel group shrinks.

        Builds a mapping from every old EP rank to itself when the rank is
        below `new_ep_size` and to ``-1`` otherwise, then asks the EPLB
        state to rearrange with that mapping and waits for CUDA to finish.
        EP rank 0 logs the start and the end.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        rank_mapping = {}
        for old_ep_rank in range(old_ep_size):
            survives = old_ep_rank < new_ep_size
            rank_mapping[old_ep_rank] = old_ep_rank if survives else -1

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts after the expert-parallel group has grown.

        Every old EP rank is mapped to itself and the EPLB state is asked to
        rearrange using that mapping and the given expert loads. EP rank 0
        logs the start and the end. `new_ep_size` is not read here.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        # Identity mapping over the old ranks.
        old_ranks = range(old_ep_size)
        rank_mapping = dict(zip(old_ranks, old_ranks))

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=rank_mapping,
        )
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Update parallel config with provided reconfig_request

        The data-parallel size, master IP and master port are always
        overwritten. The global and local data-parallel ranks are only
        overwritten when the request does not carry the
        ``KEEP_CURRENT_RANK`` marker for them.
        """
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK
        parallel_config = self.vllm_config.parallel_config

        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size

        requested_rank = reconfig_request.new_data_parallel_rank
        if requested_rank != keep_current:
            parallel_config.data_parallel_rank = requested_rank

        requested_local_rank = reconfig_request.new_data_parallel_rank_local
        if requested_local_rank != keep_current:
            parallel_config.data_parallel_rank_local = requested_local_rank

        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
675

676
677
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """
        Reconfigure the MoE modules for a new expert-parallel (EP) size.

        Updates the expert counts and the MoE parallel config of every
        ``FusedMoE`` / ``SharedFusedMoE`` module of the model (and of the
        drafter model when it is a mixture of experts), recomputes
        ``eplb_config.num_redundant_experts``, re-prepares the communication
        buffers and refreshes the model's physical-expert metadata.

        Args:
            old_ep_size: EP size before the reconfiguration.
            new_ep_size: EP size after the reconfiguration.

        Returns:
            The global expert loads obtained from ``eplb_state.rearrange``
            when ``new_ep_size >= old_ep_size``; ``None`` when scaling down
            (``new_ep_size < old_ep_size``).
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config
        moe_class_names = ("FusedMoE", "SharedFusedMoE")

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Modules are matched by class name, not with isinstance().
            return [
                module
                for module in model.modules()
                if module.__class__.__name__ in moe_class_names
            ]

        def update_moe_modules(
            moe_modules: list[FusedMoE], num_local_experts: int
        ) -> None:
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of experts"
            for module in moe_modules:
                # The local expert count is kept; only the global count scales
                # with the new EP size.
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config

        model_moe_modules = get_moe_modules(self.model_runner.model)
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts
        update_moe_modules(model_moe_modules, num_local_experts)

        # The drafter is optional; it is only considered when the runner has a
        # ``drafter`` attribute that itself exposes a ``model``.
        drafter_model = getattr(
            getattr(self.model_runner, "drafter", None), "model", None
        )
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            # Scale down: the physical expert count is read from the existing
            # EPLB state; no expert loads are returned.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            global_expert_loads = None
        else:
            # Broadcast the local expert count from group rank 0 over the EP
            # CPU group so that every rank uses the same value.
            num_local_physical_experts_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts_tensor,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(num_local_physical_experts_tensor.item())
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_loads_any = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )

        prepare_communication_buffer_for_model(self.model_runner.model)
        # Buffers are prepared for any drafter model found above, whether or
        # not it is a mixture of experts.
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads
779
780

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Tear down and re-create the distributed environment for a new
        data-parallel size.

        The steps run in this order:

        1. Record the current EP size/rank and compute the new EP size as
           ``new_data_parallel_size * TP world size * PP world size``.
        2. When scaling down, call ``_eplb_before_scale_down`` while the old
           groups still exist.
        3. Clean up the distributed environment.
        4. If this rank is marked ``SHUTDOWN_CURRENT_RANK``, stop here.
        5. Otherwise apply the new parallel config, re-initialize the
           distributed environment and reconfigure the MoE modules.
        6. When scaling up, call ``_eplb_after_scale_up`` with the global
           expert loads returned by ``_reconfigure_moe``.

        Args:
            reconfig_request: Carries the new data-parallel size and rank
                that are read here and forwarded to
                ``_reconfigure_parallel_config``.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Capture the old EP layout before the groups are destroyed below.
        ep_group = get_ep_group()
        old_ep_size = ep_group.world_size
        old_ep_rank = ep_group.rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            # A rank that is shut down must lie outside the new EP range.
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
824

825
826
827
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """
        Save the model runner's model through ``ShardedStateLoader.save_model``.

        Args:
            path: Passed through to ``ShardedStateLoader.save_model`` as the
                save location.
            pattern: Forwarded unchanged as the ``pattern`` keyword.
            max_size: Forwarded unchanged as the ``max_size`` keyword.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

840
841
842
843
844
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """
        Delegate to the model runner's ``save_tensorized_model``.

        Args:
            tensorizer_config: Forwarded unchanged, by keyword, to
                ``self.model_runner.save_tensorized_model``.
        """
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )
847

848
    def shutdown(self) -> None:
        """
        Shut down the worker's KV transfer and profiler.

        ``model_runner`` is looked up with ``getattr`` so that this method
        also works on a worker that has no (or a falsy) model runner; the
        profiler is shut down only when it is not ``None``.
        """
        if runner := getattr(self, "model_runner", None):
            runner.ensure_kv_transfer_shutdown()
        if self.profiler is not None:
            self.profiler.shutdown()
853

854
855

def init_worker_distributed_environment(
856
    vllm_config: VllmConfig,
857
    rank: int,
858
    distributed_init_method: str | None = None,
859
    local_rank: int = -1,
860
    backend: str = "nccl",
861
862
) -> None:
    """Initialize the distributed environment."""
863
    parallel_config = vllm_config.parallel_config
864
865
866
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
867
868
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

869
    init_method = distributed_init_method or "env://"
870
    init_distributed_environment(
871
        parallel_config.world_size, rank, init_method, local_rank, backend
872
    )
873

874
875
876
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
877
        parallel_config.prefill_context_parallel_size,
878
879
        parallel_config.decode_context_parallel_size,
    )
880
881
882
883

    # Init ec connector here before KV caches caches init
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)