"vscode:/vscode.git/clone" did not exist on "d06ba4ed3f9a5929eabb404842a5c02da42e960b"
gpu_worker.py 37.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any, cast
10

11
import numpy as np
12
13
import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import CUDAGraphMode, VllmConfig
18
19
20
21
22
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
23
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
24
25
26
27
28
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
29
from vllm.distributed.parallel_state import (
30
    get_pcp_group,
31
32
33
    get_pp_group,
    get_tp_group,
)
34
from vllm.logger import init_logger
35
from vllm.lora.request import LoRARequest
36
from vllm.model_executor import set_random_seed
37
from vllm.model_executor.models.interfaces import is_mixture_of_experts
38
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
39
from vllm.platforms import current_platform
40
from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
41
from vllm.sequence import IntermediateTensors
42
from vllm.tasks import SupportedTask
43
44
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
Woosuk Kwon's avatar
Woosuk Kwon committed
45
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
46
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
47
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
48
49
50
51
52
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
53
from vllm.v1.utils import report_usage_stats
54
from vllm.v1.worker.utils import is_residual_scattered_for_sp
55
from vllm.v1.worker.worker_base import WorkerBase
56
57
58
59

logger = init_logger(__name__)

if TYPE_CHECKING:
60
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
61
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
62
63


64
class Worker(WorkerBase):
65
66
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Record worker configuration; device setup happens in ``init_device``.

        Besides delegating to ``WorkerBase``, this prepares cached HF modules
        when remote code is trusted, creates the sleep-mode buffer stash, and
        selects an optional profiler from environment variables.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # Imported lazily so torch is not pulled in before initialization.
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # CPU copies of model buffers taken by a level-2 sleep.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Optional profiler, chosen via environment variables:
        #   VLLM_TORCH_PROFILER_DIR=/path/to/save/trace -> torch profiler
        #   VLLM_TORCH_CUDA_PROFILE=1                   -> CUDA profiler
        # The torch profiler wins when both are set; otherwise it stays None.
        self.profiler: Any | None = None
        if envs.VLLM_TORCH_PROFILER_DIR:
            self.profiler = TorchProfilerWrapper(
                worker_name=f"{vllm_config.instance_id}-rank-{self.rank}",
                local_rank=self.local_rank,
            )
        elif envs.VLLM_TORCH_CUDA_PROFILE:
            self.profiler = CudaProfilerWrapper()

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

106
    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep to release GPU memory.

        Level 1 passes ``("weights",)`` as ``offload_tags``; any other level
        passes no offload tags. Level 2 additionally snapshots every model
        buffer to CPU first so ``wake_up`` can copy them back. Logs how much
        memory was freed and how much is still in use.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before, _ = torch.cuda.mem_get_info()

        if level == 2:
            # Keep CPU copies of the model buffers across a level 2 sleep.
            self._sleep_saved_buffers = {
                buf_name: buf.cpu().clone()
                for buf_name, buf in self.model_runner.model.named_buffers()
            }

        offload_tags = ("weights",) if level == 1 else tuple()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total = torch.cuda.mem_get_info()
        freed_bytes = free_after - free_before
        used_bytes = total - free_after
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )
129

130
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up CuMem memory pools previously released by ``sleep``.

        Args:
            tags: Allocator tags to wake (e.g. ``"weights"``, ``"kv_cache"``).
                ``None`` wakes every pool.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers saved by a level 2 sleep. Model buffers are
        # created while the model is loaded inside the "weights" memory pool,
        # so copying into them before that pool is awake would write to
        # released memory. If "weights" is not being woken now, keep the CPU
        # copies so a later wake_up() that includes "weights" restores them.
        if self._sleep_saved_buffers and (tags is None or "weights" in tags):
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up,
        # the internal state of cache_engine must be reset,
        # especially the FP8 scaling factor.
        if (
            (tags is None or "kv_cache" in tags)
            and self.cache_config.cache_dtype.startswith("fp8")
            and hasattr(self.model_runner, "init_fp8_kv_scales")
        ):
            self.model_runner.init_fp8_kv_scales()

154
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
        """Return a context routing allocations into the CuMem pool ``tag``.

        When sleep mode is disabled this is a no-op ``nullcontext``. For the
        "weights" pool it first asserts that the allocator holds nothing yet,
        because sleep mode supports only one instance per process.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            return nullcontext()

        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        if tag == "weights":
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be used for one instance per process."
            )
        return allocator.use_memory_pool(tag=tag)
166

167
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Store the GPU/CPU KV cache block counts on the cache config."""
        cache_config = self.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

171
    def init_device(self):
        """Bind this worker to its CUDA device and construct the model runner.

        Steps:
        1. Pick the device, shifting ``local_rank`` by the DP rank when several
           DP ranks run on one node under a non-Ray launcher.
        2. Initialize the distributed environment before measuring memory, so
           NCCL buffers are already allocated.
        3. Seed RNGs, snapshot GPU memory, and check that the requested
           ``gpu_memory_utilization`` share fits into free memory.
        4. Build the V1 or V2 GPU model runner.

        Raises:
            RuntimeError: if the configured device is not a CUDA device.
            ValueError: if free memory is below the requested amount.
        """
        device = self.device_config.device
        if not (isinstance(device, torch.device) and device.type == "cuda"):
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}"
            )

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        pc = self.parallel_config
        needs_dp_rank_offset = (
            pc.data_parallel_size > 1
            and pc.data_parallel_size_local > 0
            and pc.distributed_executor_backend not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
            and self.vllm_config.parallel_config.nnodes_within_dp == 1
        )
        if needs_dp_rank_offset:
            # Prefer the local DP rank; fall back to the global DP rank.
            dp_rank = pc.data_parallel_rank_local
            if dp_rank is None:
                dp_rank = pc.data_parallel_rank

            # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
            ranks_per_dp_replica = (
                pc.pipeline_parallel_size * pc.tensor_parallel_size
            )
            self.local_rank += dp_rank * ranks_per_dp_replica
            assert self.local_rank < torch.cuda.device_count(), (
                f"DP adjusted local rank {self.local_rank} is out of bounds. "
            )

            if torch.cuda.is_available():
                num_visible = torch.cuda.device_count()
            else:
                num_visible = 0
            assert pc.local_world_size <= num_visible, (
                f"local_world_size ({pc.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({num_visible})."
            )

        self.device = torch.device(f"cuda:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)

        # Set up the distributed environment BEFORE the memory snapshot so
        # that NCCL buffers are already allocated when free memory is read.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        set_random_seed(self.model_config.seed)

        # Snapshot memory now that NCCL is initialized.
        gc.collect()
        torch.cuda.empty_cache()
        self.init_snapshot = MemorySnapshot()
        self.requested_memory = (
            self.init_snapshot.total_memory
            * self.cache_config.gpu_memory_utilization
        )
        if self.init_snapshot.free_memory < self.requested_memory:

            def to_gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            raise ValueError(
                f"Free memory on device "
                f"({to_gib(self.init_snapshot.free_memory)}/"
                f"{to_gib(self.init_snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{to_gib(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        # Construct the model runner.
        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

272
273
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
274
    def load_model(self) -> None:
        """Load the model, inside the "weights" sleep-mode pool when enabled.

        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH == "1"`` is forwarded to the model
        runner as ``eep_scale_up``.
        """
        scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        weights_pool = self._maybe_get_memory_pool_context(tag="weights")
        with weights_pool:
            self.model_runner.load_model(eep_scale_up=scale_up)
278

279
280
281
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

282
    def reload_weights(self) -> None:
        """Ask the model runner to reload the model weights."""
        runner = self.model_runner
        runner.reload_weights()
284

285
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile peak memory usage and return the bytes left for KV cache.

        A dummy forward pass (``profile_run``) is measured with
        ``memory_profiling``. The KV cache budget is the requested memory
        (``gpu_memory_utilization`` times total memory, computed in
        ``init_device``) minus the profiled non-KV-cache memory.

        If ``kv_cache_memory_bytes`` is configured, profiling is skipped. The
        profile run still executes to compile the model, and the configured
        value is returned unchanged.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """

        def gib(num_bytes):
            return num_bytes / GiB_bytes

        fixed_kv_bytes = self.cache_config.kv_cache_memory_bytes
        if fixed_kv_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()
            logger.info(
                f"Initial free memory {gib(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {gib(fixed_kv_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            return fixed_kv_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Run a forward pass with dummy inputs to measure the model's memory.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as result:
            self.model_runner.profile_run()

        self.non_torch_memory = result.non_torch_increase
        self.peak_activation_memory = result.torch_peak_increase

        initial_free = self.init_snapshot.free_memory
        free_after_profile = result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert initial_free > free_after_profile, (
            "Error in memory profiling. "
            f"Initial free memory {gib(initial_free)} GiB, "
            f"current free memory {gib(free_after_profile)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - result.non_kv_cache_memory
        )

        # Free memory at startup that lies outside the requested budget.
        unrequested_memory = initial_free - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            gib(initial_free),
            self.cache_config.gpu_memory_utilization,
            gib(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            gib(free_after_profile),
            gib(free_after_profile - unrequested_memory),
        )
        logger.debug(result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            gib(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
371

372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Get KV connector metadata from this worker if available.

        Returns ``{tp_rank: metadata}`` keyed by this worker's TP rank. Returns
        ``None`` when no KV transfer group exists, or when the connector
        reports no handshake metadata.
        """
        if not has_kv_transfer_group():
            return None

        # Connectors that don't need to exchange handshake metadata across
        # workers report None.
        metadata = get_kv_transfer_group().get_handshake_metadata()
        if metadata is None:
            return None

        return {get_tp_group().rank_in_group: metadata}

387
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the per-layer KV cache specs reported by the model runner."""
        spec = self.model_runner.get_kv_cache_spec()
        return spec

390
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The KV transfer connector is initialized first. When sleep mode is
        enabled, the cache is allocated inside the "kv_cache" CuMem memory pool.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Reuse the shared pool helper so sleep-mode pool selection lives in
        # one place. For the "kv_cache" tag it returns the allocator's
        # use_memory_pool() context, or a nullcontext when sleep mode is off.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
408
409

    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up the model, capture CUDA graphs, and warm up sampling.

        Steps:
        1. Run dummy batches, largest first, for ``compile_sizes`` not already
           covered by CUDA graph capture sizes.
        2. Warm up kernels, then capture CUDA graphs unless ``enforce_eager``.
        3. If the KV cache size came from profiling, log (at debug level) a
           suggested ``--kv-cache-memory`` value.
        4. On the last PP rank, run the sampler (or pooler) once at the max
           request count to preallocate its buffers.
        5. Re-seed RNGs so warm-up does not affect the random state.
        """
        compilation_config = self.vllm_config.compilation_config

        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        compile_sizes = compilation_config.compile_sizes
        warmup_sizes = [] if compile_sizes is None else compile_sizes.copy()
        if not self.model_config.enforce_eager:
            capture_sizes = compilation_config.cudagraph_capture_sizes
            if capture_sizes is not None:
                warmup_sizes = [s for s in warmup_sizes if s not in capture_sizes]

        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = (
            0 if self.model_config.enforce_eager else self.model_runner.capture_model()
        )

        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # The KV cache size was derived from memory_profiling, which does
            # not account for CUDA graph memory and may leave GPU memory
            # unused. Suggest explicit --kv-cache-memory values so users can
            # take fine-grained control.
            def gib(num_bytes):
                return round(num_bytes / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            logger.debug(
                f"Free memory on device "
                f"({gib(self.init_snapshot.free_memory)}/"
                f"{gib(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{gib(self.requested_memory)} GiB). "
                f"Actual usage is {gib(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {gib(self.peak_activation_memory)} GiB "
                f"for peak activation, {gib(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {gib(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{to_requested_limit}` "
                f"({gib(to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{to_gpu_limit}` "
                f"({gib(to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{gib(self.available_kv_cache_memory_bytes)} GiB."
            )

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            sched = self.scheduler_config
            max_num_reqs = min(sched.max_num_seqs, sched.max_num_batched_tokens)

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

516
517
518
    def reset_mm_cache(self) -> None:
        """Ask the model runner to reset its multimodal cache."""
        runner = self.model_runner
        runner.reset_mm_cache()

519
520
521
    def get_model(self) -> nn.Module:
        """Return the ``nn.Module`` held by the model runner."""
        model = self.model_runner.get_model()
        return model

522
523
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        tasks = self.model_runner.get_supported_tasks()
        return tasks
524

525
526
527
528
529
530
    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels this step in the profiler trace.

        The label encodes the new and cached request counts so iterations are
        easy to tell apart. The profiler is also advanced by one step. Without
        a profiler this returns a no-op ``nullcontext``.
        """
        profiler = self.profiler
        if not profiler:
            return nullcontext()

        profiler.step()

        new_count = len(scheduler_output.scheduled_new_reqs)
        cached_count = len(scheduler_output.scheduled_cached_reqs.req_ids)
        label = f"execute_new_{new_count}_cached_{cached_count}"
        return profiler.annotate_context_manager(label)

540
541
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Delegate token sampling to the model runner.

        ``grammar_output`` is forwarded unchanged.
        """
        runner_output = self.model_runner.sample_tokens(grammar_output)
        return runner_output

546
547
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one model step and exchange activations with PP neighbours.

        When tokens are scheduled, non-first PP ranks first receive
        intermediate tensors from the previous stage. If the model runner
        returns a ``ModelRunnerOutput`` or ``None``, that value is returned.
        Otherwise the produced ``IntermediateTensors`` are sent to the next PP
        stage and ``None`` is returned.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        forward_pass = num_scheduled_tokens > 0
        parallel_config = self.vllm_config.parallel_config

        all_gather_tensors = {}
        if (
            parallel_config.pipeline_parallel_size > 1
            and self.vllm_config.compilation_config.pass_config.enable_sp
            and forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            tokens_per_req = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(tokens_per_req),
                    num_scheduled_tokens_np=tokens_per_req,
                    max_num_scheduled_tokens=tokens_per_req.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            all_gather_tensors = {
                "residual": not is_residual_scattered_for_sp(
                    self.vllm_config, batch_desc.num_tokens
                )
            }

        intermediate_tensors = None
        if forward_pass and not get_pp_group().is_first_rank:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            assert received is not None
            intermediate_tensors = IntermediateTensors(received)

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if isinstance(output, (ModelRunnerOutput, NoneType)):
                return output

        # Any other result must be activations for the next PP stage.
        assert isinstance(output, IntermediateTensors)
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )
        return None
615

616
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return whatever draft token IDs the model runner hands over."""
        draft_ids = self.model_runner.take_draft_token_ids()
        return draft_ids

619
    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the configured profiler.

        Raises:
            RuntimeError: if no profiler was enabled via environment variables.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiling is not enabled.")
        toggle = profiler.start if is_start else profiler.stop
        toggle()

627
    def execute_dummy_batch(self) -> None:
        """Run a single dummy forward step with no real requests.

        The V1 runner uses ``_dummy_run(1, uniform_decode=True)``. The V2 runner
        executes an empty ``SchedulerOutput`` with ``dummy_run=True``.
        """
        if not self.use_v2_model_runner:
            self.model_runner._dummy_run(1, uniform_decode=True)
            return
        empty_output = SchedulerOutput.make_empty()
        self.model_runner.execute_model(empty_output, dummy_run=True)
634

635
636
637
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the model runner and return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

638
639
640
    def remove_lora(self, lora_id: int) -> bool:
        """Ask the model runner to remove LoRA ``lora_id``; return its result."""
        removed = self.model_runner.remove_lora(lora_id)
        return removed

641
    def list_loras(self) -> set[int]:
        """Return the set of LoRA IDs reported by the model runner."""
        lora_ids = self.model_runner.list_loras()
        return lora_ids

    def pin_lora(self, lora_id: int) -> bool:
        """Ask the model runner to pin LoRA ``lora_id``; return its result."""
        pinned = self.model_runner.pin_lora(lora_id)
        return pinned

647
648
649
650
    def check_health(self) -> None:
        """Health probe; a worker that is still running is always healthy."""

651
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts ahead of shrinking EP from ``old_ep_size``.

        EP ranks below ``new_ep_size`` keep their index. The remaining ranks
        map to ``-1`` (presumably "being removed"; confirm against
        ``EplbState.rearrange``). Synchronizes the CUDA device before logging
        completion.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        rank_mapping = {
            ep_rank: (ep_rank if ep_rank < new_ep_size else -1)
            for ep_rank in range(old_ep_size)
        }
        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts after growing EP beyond ``old_ep_size``.

        Existing EP ranks keep their index (identity mapping).
        ``global_expert_loads`` is forwarded to ``EplbState.rearrange``.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        identity_mapping = {ep_rank: ep_rank for ep_rank in range(old_ep_size)}
        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=identity_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Copy the new data-parallel topology from ``reconfig_request``.

        ``data_parallel_size``, ``data_parallel_master_ip`` and
        ``data_parallel_master_port`` are always overwritten. The global and
        local DP ranks are only overwritten when the request carries a value
        other than ``ReconfigureRankType.KEEP_CURRENT_RANK``.
        """
        cfg = self.vllm_config.parallel_config
        keep = ReconfigureRankType.KEEP_CURRENT_RANK

        cfg.data_parallel_size = reconfig_request.new_data_parallel_size

        new_rank = reconfig_request.new_data_parallel_rank
        if new_rank != keep:
            cfg.data_parallel_rank = new_rank

        new_local_rank = reconfig_request.new_data_parallel_rank_local
        if new_local_rank != keep:
            cfg.data_parallel_rank_local = new_local_rank

        cfg.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip
        cfg.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """
        Update MoE layers and EPLB metadata for a resized expert-parallel group.

        Affected modules are every ``FusedMoE``/``SharedFusedMoE`` module of the
        main model, plus those of the drafter model when it is a
        mixture-of-experts model. For each, the global expert count and
        parallel config are rebuilt for ``new_ep_size`` ranks.

        The method then recomputes ``eplb_config.num_redundant_experts``,
        re-prepares the communication buffers, and refreshes the model's
        physical-expert metadata.

        Args:
            old_ep_size: EP world size before the reconfiguration.
            new_ep_size: EP world size after the reconfiguration.

        Returns:
            ``None`` when ``new_ep_size < old_ep_size``. Otherwise, including
            an unchanged size, returns the global expert loads collected by
            ``eplb_state.rearrange(execute_shuffle=False)``.

        Raises:
            RuntimeError: if the main model contains no MoE modules.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Matches on the exact class name, so subclasses with other names
            # are not picked up.
            return [
                module
                for module in model.modules()
                if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
            ]

        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
            # Rebuild expert counts and parallel configs for the new EP size.
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of experts"
            for module in moe_modules:
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config
            return moe_modules

        model_moe_modules = get_moe_modules(self.model_runner.model)
        if not model_moe_modules:
            # Fail with a clear message instead of an IndexError below.
            raise RuntimeError(
                "Elastic EP reconfiguration requires a model with FusedMoE "
                "modules, but none were found."
            )
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts

        update_moe_modules(model_moe_modules, num_local_experts)
        drafter_model = None
        if hasattr(self.model_runner, "drafter") and hasattr(
            self.model_runner.drafter, "model"
        ):
            drafter_model = self.model_runner.drafter.model
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            # On scale-down, reinitialize_distributed has already resharded
            # experts via _eplb_before_scale_down; read the resulting layout.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            global_expert_loads = None
        else:
            # Every EP rank adopts EP rank 0's local expert count.
            num_local_physical_experts_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts_tensor,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(num_local_physical_experts_tensor.item())
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            # Collect the loads without shuffling; the caller performs the
            # actual rearrangement once the new group is ready.
            global_expert_loads_any = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )
        prepare_communication_buffer_for_model(self.model_runner.model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this worker's distributed state for an elastic-EP resize.

        Steps, in order:
          1. If the EP group shrinks, reshard experts while the old process
             groups still exist (``_eplb_before_scale_down``).
          2. Tear down the current distributed environment.
          3. Return early if this rank is marked ``SHUTDOWN_CURRENT_RANK``.
          4. Apply the new DP settings and re-create the process groups.
          5. Update MoE metadata (``_reconfigure_moe``). If the EP group grew,
             also redistribute experts (``_eplb_after_scale_up``).
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        ep_group = get_ep_group()
        old_ep_size = ep_group.world_size
        old_ep_rank = ep_group.rank
        # EP spans the DP, PCP and TP dimensions. These are the same group
        # sizes _reconfigure_moe hands to FusedMoEParallelConfig.make.
        # Pipeline stages are not part of it, so PP must not be multiplied in.
        # Otherwise the comparison with old_ep_size is wrong whenever PP > 1
        # or PCP > 1.
        # NOTE(review): assumes parallel_state builds the EP group from
        # DP x PCP x TP ranks; confirm if that layout changes.
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_pcp_group().world_size
            * get_tp_group().world_size
        )

        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            # A rank being removed must lie outside the new EP group.
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Write the model's weights to ``path`` via ``ShardedStateLoader``.

        ``pattern`` and ``max_size`` are forwarded unchanged to
        ``ShardedStateLoader.save_model``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Delegate tensorizer serialization of the model to the model runner."""
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)

    def shutdown(self) -> None:
        """Release worker-owned resources.

        Shuts down the model runner's KV-transfer state, then the profiler.
        Both attributes are looked up with ``getattr`` so that shutting down a
        worker on which they were never set does not raise ``AttributeError``.
        A partially constructed worker is one example.
        """
        runner = getattr(self, "model_runner", None)
        if runner is not None:
            runner.ensure_kv_transfer_shutdown()
        # Guarded the same way as model_runner above.
        profiler = getattr(self, "profiler", None)
        if profiler is not None:
            profiler.shutdown()


def init_worker_distributed_environment(
899
    vllm_config: VllmConfig,
900
    rank: int,
901
    distributed_init_method: str | None = None,
902
    local_rank: int = -1,
903
    backend: str = "nccl",
904
905
) -> None:
    """Initialize the distributed environment."""
906
    parallel_config = vllm_config.parallel_config
907
908
909
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
910
911
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

912
    init_method = distributed_init_method or "env://"
913
    init_distributed_environment(
914
        parallel_config.world_size, rank, init_method, local_rank, backend
915
    )
916

917
918
919
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
920
        parallel_config.prefill_context_parallel_size,
921
922
        parallel_config.decode_context_parallel_size,
    )
923
924
925
926

    # Init ec connector here before KV caches caches init
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)