gpu_worker.py 38.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4

5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from types import NoneType
9
from typing import TYPE_CHECKING, Any, cast
10

11
import numpy as np
12
13
import torch
import torch.distributed
14
import torch.nn as nn
15

16
import vllm.envs as envs
17
from vllm.config import CUDAGraphMode, VllmConfig
18
from vllm.config.compilation import CompilationMode
19
20
21
22
23
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
    set_custom_all_reduce,
)
24
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
25
26
27
28
29
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
30
from vllm.distributed.parallel_state import (
31
    get_pcp_group,
32
33
34
    get_pp_group,
    get_tp_group,
)
35
from vllm.logger import init_logger
36
from vllm.lora.request import LoRARequest
37
from vllm.model_executor import set_random_seed
38
from vllm.model_executor.models.interfaces import is_mixture_of_experts
39
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
40
from vllm.platforms import current_platform
41
from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
42
from vllm.sequence import IntermediateTensors
43
from vllm.tasks import SupportedTask
44
45
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
Woosuk Kwon's avatar
Woosuk Kwon committed
46
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
47
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
48
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
49
50
51
52
53
from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ModelRunnerOutput,
)
54
from vllm.v1.utils import report_usage_stats
55
from vllm.v1.worker.utils import is_residual_scattered_for_sp
56
from vllm.v1.worker.worker_base import WorkerBase
57
58
59
60

# Module-level logger shared by the worker defined in this file.
logger = init_logger(__name__)

# These imports are only evaluated by static type checkers; at runtime the
# names are not bound, so they may only appear in annotations.
if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
63
64


65
class Worker(WorkerBase):
66
67
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Create the worker's host-side state.

        All arguments are forwarded unchanged to ``WorkerBase.__init__``.
        Besides that, this constructor prepares the storage used by
        sleep/wake-up, selects an optional profiler from environment
        settings, and records which model-runner implementation to use.
        The device and the model runner themselves are set up later in
        ``init_device``.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # CPU copies of model buffers taken by a level-2 sleep; consumed and
        # cleared again by wake_up.
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch/CUDA profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        # VLLM_TORCH_CUDA_PROFILE=1
        # The torch profiler takes precedence; with neither variable set the
        # profiler stays None.
        self.profiler: Any | None = None
        if envs.VLLM_TORCH_PROFILER_DIR:
            self.profiler = TorchProfilerWrapper(
                worker_name=f"{vllm_config.instance_id}-rank-{self.rank}",
                local_rank=self.local_rank,
            )
        elif envs.VLLM_TORCH_CUDA_PROFILE:
            self.profiler = CudaProfilerWrapper()

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

107
    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep and log how much memory was freed.

        With ``level == 2`` every named buffer of the model is first copied
        to CPU so that ``wake_up`` can restore it. With ``level == 1`` the
        allocator is asked to offload the ``"weights"`` tag; for any other
        level no tag is offloaded.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_before = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            saved: dict[str, torch.Tensor] = {}
            for buf_name, buf in self.model_runner.model.named_buffers():
                saved[buf_name] = buf.cpu().clone()
            self._sleep_saved_buffers = saved

        offload_tags = ("weights",) if level == 1 else tuple()
        CuMemAllocator.get_instance().sleep(offload_tags=offload_tags)

        free_after, total_bytes = torch.cuda.mem_get_info()
        released = free_after - free_before
        still_used = total_bytes - free_after
        assert released >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            released / GiB_bytes,
            still_used / GiB_bytes,
        )
130

131
    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the CuMem allocator and restore state saved by ``sleep``.

        ``tags`` is passed through to the allocator. Buffers saved by a
        level-2 sleep are copied back into the model and the saved copies
        are dropped. If the KV cache is among the woken tags (or ``tags`` is
        ``None``) and the cache dtype starts with ``"fp8"``, the model
        runner's ``init_fp8_kv_scales`` is called when it exists.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        CuMemAllocator.get_instance().wake_up(tags)

        # Restore the buffers after level 2 sleep
        saved = self._sleep_saved_buffers
        if saved:
            for buf_name, buf in self.model_runner.model.named_buffers():
                if buf_name in saved:
                    buf.data.copy_(saved[buf_name].data)
            self._sleep_saved_buffers = {}

        # If the KV cache has just been woken up,
        # the internal state of cache_engine must be reset,
        # especially the FP8 scaling factor.
        kv_cache_woken = tags is None or "kv_cache" in tags
        if not kv_cache_woken:
            return
        if not self.cache_config.cache_dtype.startswith("fp8"):
            return
        if hasattr(self.model_runner, "init_fp8_kv_scales"):
            self.model_runner.init_fp8_kv_scales()

155
    def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
156
157
158
159
160
161
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
162
163
                    "Sleep mode can only be used for one instance per process."
                )
164
            return allocator.use_memory_pool(tag=tag)
165
        else:
166
            return nullcontext()
167

168
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the GPU and CPU block counts on the cache config."""
        cache_cfg = self.cache_config
        cache_cfg.num_gpu_blocks = num_gpu_blocks
        cache_cfg.num_cpu_blocks = num_cpu_blocks

172
    def init_device(self):
        """Bind this worker to its CUDA device and build the model runner.

        Steps, in order: optionally offset ``self.local_rank`` for data
        parallelism, select the device, initialize the distributed
        environment, seed the RNG, take the initial memory snapshot and
        validate it against ``gpu_memory_utilization``, then construct the
        V1 or V2 model runner.

        Raises:
            ValueError: If the free memory in the initial snapshot is below
                the requested memory.
            RuntimeError: If the configured device is not a CUDA
                ``torch.device``.
        """
        device = self.device_config.device
        if isinstance(device, torch.device) and device.type == "cuda":
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            if (
                self.parallel_config.data_parallel_size > 1
                and self.parallel_config.data_parallel_size_local > 0
                and self.parallel_config.distributed_executor_backend
                not in ["ray", "external_launcher"]
                and self.vllm_config.parallel_config.data_parallel_backend != "ray"
                and self.vllm_config.parallel_config.nnodes_within_dp == 1
            ):
                # Use local DP rank if available, otherwise use global DP rank.
                dp_local_rank = self.parallel_config.data_parallel_rank_local
                if dp_local_rank is None:
                    dp_local_rank = self.parallel_config.data_parallel_rank

                tp_pp_world_size = (
                    self.parallel_config.pipeline_parallel_size
                    * self.parallel_config.tensor_parallel_size
                )

                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
                # NOTE(review): this mutates self.local_rank in place, so the
                # offset would be applied again if init_device were called
                # twice on the same worker.
                self.local_rank += dp_local_rank * tp_pp_world_size
                assert self.local_rank < torch.cuda.device_count(), (
                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
                )
                visible_device_count = (
                    torch.cuda.device_count() if torch.cuda.is_available() else 0
                )
                assert self.parallel_config.local_world_size <= visible_device_count, (
                    f"local_world_size ({self.parallel_config.local_world_size}) must "
                    f"be less than or equal to the number of visible devices "
                    f"({visible_device_count})."
                )
            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.check_if_supports_dtype(self.model_config.dtype)

            # Initialize the distributed environment BEFORE taking
            # memory snapshot
            # This ensures NCCL buffers are allocated before we measure
            # available memory
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
                current_platform.dist_backend,
            )

            # Set random seed.
            set_random_seed(self.model_config.seed)

            # Now take memory snapshot after NCCL is initialized
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            # Memory budget for this worker: a fraction of the snapshot's
            # total memory.
            self.requested_memory = (
                self.init_snapshot.total_memory
                * self.cache_config.gpu_memory_utilization
            )
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # Construct the model runner
        # The implementation is imported lazily, so only the selected runner
        # module is loaded here.
        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            from vllm.v1.worker.gpu_model_runner import (
                GPUModelRunner as GPUModelRunnerV1,
            )

            self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

273
274
    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
275
    def load_model(self) -> None:
        """Load the model weights inside the ``"weights"`` memory-pool context.

        The ``eep_scale_up`` flag passed to the model runner is true exactly
        when the ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` environment variable
        equals ``"1"``.
        """
        scale_up_flag = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH")
        is_scale_up_launch = scale_up_flag == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=is_scale_up_launch)
279

280
281
282
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Pass the configuration overrides on to the model runner."""
        runner = self.model_runner
        runner.update_config(overrides)

283
    def reload_weights(self) -> None:
        """Ask the model runner to reload its weights."""
        runner = self.model_runner
        runner.reload_weights()
285

286
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        If ``cache_config.kv_cache_memory_bytes`` is set, a profile run is
        still executed but memory profiling is skipped and the configured
        value is returned as-is.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            The number of bytes available for the KV cache.
        """
        GiB = lambda b: b / GiB_bytes
        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
            # still need a profile run which compiles the model for
            # max_num_batched_tokens
            self.model_runner.profile_run()

            msg = (
                f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
                "KV Cache as specified by kv_cache_memory_bytes config and "
                "skipped memory profiling. This does not respect the "
                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                "config when you want manual control of KV cache memory "
                "size. If OOM'ed, check the difference of initial free "
                "memory between the current run and the previous run "
                "where kv_cache_memory_bytes is suggested and update it "
                "correspondingly."
            )
            logger.info(msg)
            return kv_cache_memory_bytes

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        # Kept on the instance because compile_or_warm_up_model reads them
        # later when building its KV cache size suggestion.
        self.non_torch_memory = profile_result.non_torch_increase
        self.peak_activation_memory = profile_result.torch_peak_increase

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container."
        )
        # KV cache budget = requested memory minus everything the profiler
        # attributed to non-KV-cache usage.
        self.available_kv_cache_memory_bytes = (
            self.requested_memory - profile_result.non_kv_cache_memory
        )

        # Memory that was free at startup but lies outside the requested
        # budget; subtracted below to report free memory within the budget.
        unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
            GiB(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            GiB(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            GiB(free_gpu_memory),
            GiB(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB",
            GiB(self.available_kv_cache_memory_bytes),
            scope="local",
        )
        gc.collect()

        return int(self.available_kv_cache_memory_bytes)
372

373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Return this worker's KV connector handshake metadata, if any.

        The result is a single-entry dict keyed by this worker's rank within
        the TP group. ``None`` is returned when no KV transfer group exists
        or when the connector reports no handshake metadata.
        """
        if has_kv_transfer_group():
            # Connectors that don't need to exchange handshake metadata
            # across workers report None here.
            handshake = get_kv_transfer_group().get_handshake_metadata()
            if handshake is not None:
                return {get_tp_group().rank_in_group: handshake}
        return None

388
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        return kv_cache_spec

391
    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        The KV transfer connector is initialized first, then the model
        runner allocates the KV cache. With sleep mode enabled the
        allocation happens inside the allocator's ``"kv_cache"`` memory
        pool; otherwise it happens in a no-op context.
        """
        # Init kv cache connector here, because it requires
        # `kv_cache_config`.
        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
        # because `initialize_kv_cache` will inject kv cache groups not
        # related to kv cache connector (e.g. kv cache sharing layers).
        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)

        # Use the same helper as load_model instead of repeating the
        # sleep-mode check here. For the "kv_cache" tag it returns the
        # allocator's memory pool when sleep mode is on and a nullcontext
        # otherwise, which matches the previous inline branches.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)
409
410

    def compile_or_warm_up_model(self) -> None:
        """Warm up / compile the model, capture CUDA graphs, and warm the sampler.

        Order of operations: dummy runs for the selected warmup sizes,
        kernel warmup, CUDA graph capture (unless ``enforce_eager``), an
        optional debug message suggesting a KV cache size, a sampler/pooler
        warmup on the last PP rank, and finally a reset of the random seed.
        """
        warmup_sizes = []

        if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # warm up sizes that are not in cudagraph capture sizes,
            # but users still want to compile for better performance,
            # e.g. for the max-num-batched token size in chunked prefill.
            compile_sizes = self.vllm_config.compilation_config.compile_sizes
            # Copy so that the appends below do not mutate the config's list.
            warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
            cg_capture_sizes: list[int] = []

            if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
                cg_capture_sizes = [] if cg_sizes is None else cg_sizes
                warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes]

            compile_ranges = self.vllm_config.compilation_config.get_compile_ranges()
            # For each compile_range, if none of the batch sizes
            # in warmup_sizes or cudagraph_capture_sizes are in the range,
            # add the end of the range to ensure compilation/warmup.
            all_sizes = set(cg_capture_sizes)
            all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
            for compile_range in compile_ranges:
                if not any(x in compile_range for x in all_sizes):
                    warmup_sizes.append(compile_range.end)

        # We skip EPLB here since we don't want to record dummy metrics
        # Sizes are processed from largest to smallest.
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
        # LoRAs are kept across the dummy runs (remove_lora=False) and
        # removed once, after the last one.
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        # peak_activation_memory is only set by determine_available_memory's
        # profiling path, hence the hasattr guard.
        if self.cache_config.kv_cache_memory_bytes is None and hasattr(
            self, "peak_activation_memory"
        ):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            GiB = lambda b: round(b / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (
                self.model_runner.model_memory_usage
                + self.peak_activation_memory
                + self.non_torch_memory
                + cuda_graph_memory_bytes
            )
            # Suggestion relative to all memory that was free at startup.
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )
            # Suggestion relative to the gpu_memory_utilization budget.
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory)
                - non_kv_cache_memory
                - redundancy_buffer_memory
            )

            msg = (
                f"Free memory on device "
                f"({GiB(self.init_snapshot.free_memory)}/"
                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). "
                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{GiB(self.available_kv_cache_memory_bytes)} GiB."
            )

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

533
534
535
    def reset_mm_cache(self) -> None:
        """Ask the model runner to reset its multimodal cache."""
        runner = self.model_runner
        runner.reset_mm_cache()

536
537
538
    def get_model(self) -> nn.Module:
        """Return the model object held by the model runner."""
        model = self.model_runner.get_model()
        return model

539
540
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        supported = self.model_runner.get_supported_tasks()
        return supported
541

542
543
544
545
546
547
    def annotate_profile(self, scheduler_output):
        """Return a context manager that labels this step in the profile.

        Without an active profiler this is a no-op ``nullcontext``.
        Otherwise the profiler is stepped once and the returned annotation
        is named after the number of new and cached requests scheduled in
        this iteration, which makes iterations easy to tell apart in a trace.
        """
        profiler = self.profiler
        if not profiler:
            return nullcontext()

        profiler.step()

        new_count = len(scheduler_output.scheduled_new_reqs)
        cached_count = len(scheduler_output.scheduled_cached_reqs.req_ids)
        label = f"execute_new_{new_count}_cached_{cached_count}"
        return profiler.annotate_context_manager(label)

557
558
    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Run the model runner's token sampling step under inference mode."""
        sampled_output = self.model_runner.sample_tokens(grammar_output)
        return sampled_output

563
564
    @torch.inference_mode()
    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one model step, handling pipeline-parallel tensor exchange.

        Ranks that are not the first PP rank receive intermediate tensors
        before running the model. If the model runner returns a
        ``ModelRunnerOutput`` or ``None`` it is returned directly; otherwise
        the returned ``IntermediateTensors`` are sent through the PP group
        and ``None`` is returned.
        """
        intermediate_tensors = None
        # With no scheduled tokens there is no forward pass, so nothing is
        # received from the PP group below.
        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        # Maps tensor names to whether they must be all-gathered over the TP
        # group during PP send/recv; stays empty unless SP and PP are both on.
        all_gather_tensors = {}
        compilation_config = self.vllm_config.compilation_config
        parallel_config = self.vllm_config.parallel_config

        if (
            parallel_config.pipeline_parallel_size > 1
            and compilation_config.pass_config.enable_sp
            and forward_pass
        ):
            # currently only supported by V1 GPUModelRunner
            assert not self.use_v2_model_runner
            num_scheduled_tokens_np = np.array(
                list(scheduler_output.num_scheduled_tokens.values()),
                dtype=np.int32,
            )
            # TODO(lucas): This is pretty gross; ideally we should only ever call
            # `_determine_batch_execution_and_padding` once (will get called again
            # in `execute_model`) but this requires a larger refactor of PP.
            _, batch_desc, _, _, _ = (
                self.model_runner._determine_batch_execution_and_padding(
                    num_tokens=num_scheduled_tokens,
                    num_reqs=len(num_scheduled_tokens_np),
                    num_scheduled_tokens_np=num_scheduled_tokens_np,
                    max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
                )
            )
            # The residual only needs gathering when it is NOT scattered for
            # sequence parallelism at this batch size.
            all_gather_tensors = {
                "residual": not is_residual_scattered_for_sp(
                    self.vllm_config, batch_desc.num_tokens
                )
            }

        if forward_pass and not get_pp_group().is_first_rank:
            tensor_dict = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                all_gather_tensors=all_gather_tensors,
            )
            assert tensor_dict is not None
            intermediate_tensors = IntermediateTensors(tensor_dict)

        with self.annotate_profile(scheduler_output):
            output = self.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if isinstance(output, (ModelRunnerOutput, NoneType)):
                return output

        # From here on the runner produced intermediate tensors, which is
        # only valid on a non-last PP rank outside the external launcher.
        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert (
            parallel_config.distributed_executor_backend != "external_launcher"
            and not get_pp_group().is_last_rank
        )

        # The same all_gather_tensors mapping is used as on the receive side.
        get_pp_group().send_tensor_dict(
            output.tensors,
            all_gather_group=get_tp_group(),
            all_gather_tensors=all_gather_tensors,
        )

        return None
632

633
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return whatever draft token ids the model runner hands over."""
        draft_ids = self.model_runner.take_draft_token_ids()
        return draft_ids

636
    def profile(self, is_start: bool = True):
        """Start or stop the configured profiler.

        Raises:
            RuntimeError: If no profiler was configured for this worker.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiling is not enabled.")
        toggle = profiler.start if is_start else profiler.stop
        toggle()

644
    def execute_dummy_batch(self) -> None:
        """Run one dummy step through whichever model runner is in use.

        The V2 runner is driven with an empty scheduler output and
        ``dummy_run=True``; the V1 runner performs a single-token
        uniform-decode dummy run.
        """
        if not self.use_v2_model_runner:
            self.model_runner._dummy_run(1, uniform_decode=True)
            return

        empty_output = SchedulerOutput.make_empty()
        self.model_runner.execute_model(empty_output, dummy_run=True)
651

652
653
654
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward the LoRA request to the model runner and return its result."""
        added = self.model_runner.add_lora(lora_request)
        return added

655
656
657
    def remove_lora(self, lora_id: int) -> bool:
        """Forward the LoRA removal to the model runner and return its result."""
        removed = self.model_runner.remove_lora(lora_id)
        return removed

658
    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        lora_ids = self.model_runner.list_loras()
        return lora_ids

    def pin_lora(self, lora_id: int) -> bool:
        """Forward the LoRA pin request to the model runner and return its result."""
        pinned = self.model_runner.pin_lora(lora_id)
        return pinned

664
665
666
667
    def check_health(self) -> None:
        """Health probe that performs no checks.

        The worker is considered healthy for as long as it is running, so
        simply returning is the success signal.
        """
        return None

668
    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
        """Reshard experts ahead of shrinking the expert-parallel group.

        Each old EP rank below ``new_ep_size`` maps to itself and every
        other old rank maps to ``-1``. The EPLB state is then rearranged
        with that mapping and CUDA is synchronized before returning.
        Progress is logged on EP rank 0 only.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info(
                "[Elastic EP] Starting expert resharding before scaling down..."
            )

        rank_mapping: dict[int, int] = {}
        for ep_rank in range(old_ep_size):
            rank_mapping[ep_rank] = ep_rank if ep_rank < new_ep_size else -1

        eplb_state = self.model_runner.eplb_state
        assert eplb_state is not None
        eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=None,
            rank_mapping=rank_mapping,
        )
        torch.cuda.synchronize()

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
        self,
        old_ep_size: int,
        new_ep_size: int,
        global_expert_loads: list[torch.Tensor] | None,
    ) -> None:
        """Reshard experts after the expert-parallel group has grown.

        Every old EP rank is mapped to itself, and the mapping is passed to
        ``eplb_state.rearrange`` together with ``global_expert_loads``, with
        the shuffle executed. EP rank 0 logs the start and the end.
        ``new_ep_size`` is accepted but not read by this method.
        """
        from vllm.distributed.parallel_state import get_ep_group

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding after scaling up...")

        # Identity mapping over the old ranks.
        old_ranks = range(old_ep_size)
        rank_mapping = dict(zip(old_ranks, old_ranks))

        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            execute_shuffle=True,
            global_expert_loads=global_expert_loads,
            rank_mapping=rank_mapping,
        )

        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """
        Copy the data-parallel settings of reconfig_request into
        self.vllm_config.parallel_config.

        The size, master IP and master port are always overwritten. The rank
        and local rank are only overwritten when the request does not carry
        ReconfigureRankType.KEEP_CURRENT_RANK for that field.
        """
        request = reconfig_request
        keep_current = ReconfigureRankType.KEEP_CURRENT_RANK
        parallel_config = self.vllm_config.parallel_config

        parallel_config.data_parallel_size = request.new_data_parallel_size

        if request.new_data_parallel_rank != keep_current:
            parallel_config.data_parallel_rank = request.new_data_parallel_rank

        if request.new_data_parallel_rank_local != keep_current:
            new_local_rank = request.new_data_parallel_rank_local
            parallel_config.data_parallel_rank_local = new_local_rank

        parallel_config.data_parallel_master_ip = request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_port = (
            request.new_data_parallel_master_port
        )
735

736
737
    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """
        Refresh the MoE modules and EPLB settings for a resized EP group.

        All modules whose class is named FusedMoE or SharedFusedMoE in the
        main model (and in the drafter model, when one exists and is a
        mixture of experts) get their expert counts and parallel config
        rebuilt for new_ep_size. The number of redundant experts in the EPLB
        config is then recomputed, communication buffers are prepared again,
        and the main model's physical-expert metadata is updated.

        Returns None when new_ep_size < old_ep_size. Otherwise returns the
        result of eplb_state.rearrange(execute_shuffle=False), treated as
        the list of global expert load tensors.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config
        moe_class_names = ("FusedMoE", "SharedFusedMoE")

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            # Modules are selected by class name rather than by isinstance.
            found = []
            for submodule in model.modules():
                if submodule.__class__.__name__ in moe_class_names:
                    found.append(submodule)
            return found

        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
            assert all(
                moe.moe_config.num_local_experts == num_local_experts
                for moe in moe_modules
            ), "All MoE modules must have the same number of experts"
            for moe in moe_modules:
                moe_config = moe.moe_config
                moe_config.num_experts = num_local_experts * new_ep_size
                moe.global_num_experts = moe_config.num_experts
                moe.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=get_tp_group().world_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    vllm_parallel_config=parallel_config,
                )
                moe_config.moe_parallel_config = moe.moe_parallel_config
            return moe_modules

        main_model = self.model_runner.model
        main_moe_modules = get_moe_modules(main_model)
        num_local_experts = main_moe_modules[0].moe_config.num_local_experts
        update_moe_modules(main_moe_modules, num_local_experts)

        # None when there is no drafter, or the drafter has no model.
        drafter_model = getattr(
            getattr(self.model_runner, "drafter", None), "model", None
        )
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # The drafter must agree with the main model on local experts.
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            # Scale down: sizes are read from the existing EPLB maps.
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            eplb_state = self.model_runner.eplb_state
            new_physical_experts = (
                eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            global_expert_loads = None
        else:
            # Take the local expert count broadcast from source rank 0 of
            # the EP CPU group.
            local_count_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                local_count_tensor,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(local_count_tensor.item())
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            rearrange_result = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            global_expert_loads = cast(list[torch.Tensor], rearrange_result)
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )

        prepare_communication_buffer_for_model(main_model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        main_model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads
839
840

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the distributed environment for a new DP size.

        Steps, in order:
          1. Compute the new EP size as new DP size * TP size * PP size.
          2. When shrinking, reshard experts before anything is torn down.
          3. Clean up the current distributed environment.
          4. Return early if this rank is marked
             ``ReconfigureRankType.SHUTDOWN_CURRENT_RANK``.
          5. Apply the request to the parallel config and re-initialize the
             worker's distributed environment.
          6. Reconfigure the MoE modules and, when growing, reshard experts
             with the global expert loads that step returned.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        # Capture the old EP layout before the groups are destroyed.
        ep_group = get_ep_group()
        old_ep_size = ep_group.world_size
        old_ep_rank = ep_group.rank

        new_dp_size = reconfig_request.new_data_parallel_size
        new_ep_size = new_dp_size * get_tp_group().world_size * get_pp_group().world_size

        scaling_down = new_ep_size < old_ep_size
        if scaling_down:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        shutdown_rank = ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        if reconfig_request.new_data_parallel_rank == shutdown_rank:
            assert old_ep_rank >= new_ep_size
            # This rank is being shut down; nothing is re-initialized.
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size <= old_ep_size:
            return
        assert global_expert_loads is not None
        self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
884

885
886
887
    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save the runner's model through ``ShardedStateLoader.save_model``.

        Args:
            path: Passed to ``save_model`` as the destination.
            pattern: Forwarded unchanged as the ``pattern`` keyword.
            max_size: Forwarded unchanged as the ``max_size`` keyword.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        model = self.model_runner.model
        ShardedStateLoader.save_model(model, path, pattern=pattern, max_size=max_size)

900
901
902
903
904
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Forward ``tensorizer_config`` to the model runner's
        ``save_tensorized_model`` as a keyword argument."""
        save = self.model_runner.save_tensorized_model
        save(tensorizer_config=tensorizer_config)
907

908
    def shutdown(self) -> None:
        """Release the resources held by this worker.

        Calls ``ensure_kv_transfer_shutdown()`` on the model runner when one
        is set and truthy, then ``shutdown()`` on the profiler when one is
        set and not ``None``. Both attributes are looked up defensively, so
        calling this on a worker that lacks either attribute is a no-op for
        that part.
        """
        if runner := getattr(self, "model_runner", None):
            runner.ensure_kv_transfer_shutdown()

        # Look the profiler up the same way as model_runner, so a missing
        # attribute is treated like a disabled profiler instead of raising
        # AttributeError.
        profiler = getattr(self, "profiler", None)
        if profiler is not None:
            profiler.shutdown()
913

914
915

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: str | None = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Set up the distributed state for one worker.

    In order: initializes batch invariance, toggles custom all-reduce from
    the parallel config, initializes the distributed environment (falling
    back to ``"env://"`` when no init method is given), ensures the model
    parallel groups exist, and ensures the EC transfer is initialized.
    """
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    parallel_config = vllm_config.parallel_config

    init_batch_invariance()
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    if distributed_init_method:
        init_method = distributed_init_method
    else:
        init_method = "env://"
    init_distributed_environment(
        parallel_config.world_size, rank, init_method, local_rank, backend
    )

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.prefill_context_parallel_size,
        parallel_config.decode_context_parallel_size,
    )

    # Init the EC connector here, before the KV caches are initialized.
    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)