core.py 65.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from inspect import isclass, signature
13
from logging import DEBUG
14
from typing import Any, TypeVar, cast
15

16
import msgspec
17
18
import zmq

19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
21
from vllm.envs import enable_envs_cache
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.tasks import POOLING_TASKS, SupportedTask
27
from vllm.tracing import instrument, maybe_init_worker_tracer
28
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
29
30
31
32
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
33
from vllm.utils.hashing import get_hash_fn_by_name
34
from vllm.utils.network_utils import make_zmq_socket
35
from vllm.utils.system_utils import decorate_logs, set_process_title
36
37
38
39
40
41
42
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
43
from vllm.v1.core.sched.interface import SchedulerInterface
44
from vllm.v1.core.sched.output import SchedulerOutput
45
from vllm.v1.engine import (
46
    EngineCoreOutput,
47
48
49
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
50
    FinishReason,
51
52
53
54
55
56
57
58
59
60
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
61
from vllm.v1.executor import Executor
62
from vllm.v1.kv_cache_interface import KVCacheConfig
63
from vllm.v1.metrics.stats import SchedulerStats
64
from vllm.v1.outputs import ModelRunnerOutput
65
from vllm.v1.request import Request, RequestStatus
66
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
67
from vllm.v1.structured_output import StructuredOutputManager
68
from vllm.v1.utils import compute_iteration_details
69
70
71
72
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

73
# Polling timeout, presumably in seconds (per the `_S` suffix) — confirm at
# the use site.
POLLING_TIMEOUT_S = 2.5
74
# Handshake timeout, presumably in minutes (per the `_MINS` suffix) — confirm
# at the use site.
HANDSHAKE_TIMEOUT_MINS = 5
75

76
_R = TypeVar("_R")  # Return type for collective_rpc
77

78
79
80
81

class EngineCore:
    """Inner loop of vLLM's Engine."""

82
83
84
85
86
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
        include_finished_set: bool = False,
    ):
        """Build the executor, KV caches, scheduler and step machinery.

        Args:
            vllm_config: Engine configuration. It is mutated here: the cache
                block counts are written back into ``cache_config`` and
                chunked prefill may be switched off in ``scheduler_config``.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Stored on the instance and forwarded to the scheduler.
            executor_fail_callback: When not None, registered on the executor
                as its failure callback.
            include_finished_set: Forwarded to the scheduler constructor.
        """
        # The general plugins have to be loaded at the engine/scheduler
        # level as well.
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        # Only log the config when the local data-parallel rank is falsy.
        if not vllm_config.parallel_config.data_parallel_rank_local:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Model executor, plus the optional failure hook.
        executor = executor_class(vllm_config)
        self.model_executor = executor
        if executor_fail_callback is not None:
            executor.register_failure_callback(executor_fail_callback)

        self.available_gpu_memory_for_kv_cache = -1

        # Set up the KV caches, then record the resulting block counts in the
        # cache config and push them to the workers.
        kv_init_result = self._initialize_kv_caches(vllm_config)
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = kv_init_result

        cache_config = vllm_config.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Scheduler setup.
        scheduler_config = vllm_config.scheduler_config
        scheduler_cls = scheduler_config.get_scheduler_cls()

        # Encoder models without KV cache don't support chunked prefill.
        # But do SSM models?
        if (
            len(kv_cache_config.kv_cache_groups) == 0
            and scheduler_config.enable_chunked_prefill
        ):
            logger.warning("Disabling chunked prefill for model without KVCache")
            scheduler_config.enable_chunked_prefill = False

        # The scheduler's block size is the cache block size scaled by both
        # context-parallel sizes.
        parallel_config = vllm_config.parallel_config
        scheduler_block_size = vllm_config.cache_config.block_size
        scheduler_block_size *= parallel_config.decode_context_parallel_size
        scheduler_block_size *= parallel_config.prefill_context_parallel_size

        self.scheduler: SchedulerInterface = scheduler_cls(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=include_finished_set,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        mm_registry = MULTIMODAL_REGISTRY
        self.mm_registry = mm_registry
        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
            vllm_config
        )

        # When the scheduler has a KV connector, gather the handshake
        # metadata from all workers so that the scheduler-side connector has
        # the full context.
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collected after KV cache registration.
            worker_metadata = self.model_executor.get_kv_connector_handshake_metadata()

            if worker_metadata:
                # One dict per worker, each already shaped {tp_rank: metadata};
                # fold them into a single dict, skipping None entries.
                merged_metadata: dict[int, Any] = {}
                for per_worker in worker_metadata:
                    if per_worker is None:
                        continue
                    merged_metadata.update(per_worker)
                kv_connector.set_xfer_handshake_metadata(merged_metadata)

        # Batch queue for scheduled batches. It lets batches be scheduled and
        # executed asynchronously, and pipeline parallelism requires it to
        # eliminate pipeline bubbles. Only created when the executor allows
        # more than one concurrent batch.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        ec_transfer_config = vllm_config.ec_transfer_config
        self.is_ec_producer = (
            ec_transfer_config is not None and ec_transfer_config.is_ec_producer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # A request block hasher is only built when prefix caching is enabled
        # or a KV connector is present.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, hash_fn
            )

        # Pick the step implementation that matches the batch-queue setup.
        if self.batch_queue is None:
            self.step_fn = self.step
        else:
            self.step_fn = self.step_with_batch_queue
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        self.aborts_queue = queue.Queue[list[str]]()

        # Mark the startup heap as static so that the GC ignores it; this
        # reduces pause times of oldest-generation collections.
        freeze_gc_heap()
        # If enabled, attach the GC debugger after the static freeze.
        maybe_attach_gc_debug_callback()
        # Enable the environment variable cache (i.e. assume no more
        # environment variable overrides after this point).
        enable_envs_cache()
220

221
    @instrument(span_name="Prepare model")
    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV caches, initialize them on the executor and warm up.

        Args:
            vllm_config: Engine configuration, passed to
                ``get_kv_cache_configs``. Its ``model_config.max_model_len``
                is compared before and after that call; a changed value is
                pushed to the workers.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 here.
        """
        # Monotonic clock: the elapsed time reported below must not be skewed
        # by wall-clock adjustments.
        start = time.monotonic()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic EP scale-up launch: take the KV cache memory size
                # from the DP group instead of profiling locally.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
280

281
282
283
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.model_executor.supported_tasks

284
285
    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.
286

287
288
289
        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case
        """
290
291
292
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
293
294
                f"request_id must be a string, got {type(request.request_id)}"
            )
295

296
        if pooling_params := request.pooling_params:
297
            supported_pooling_tasks = [
298
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
299
300
            ]

301
            if pooling_params.task not in supported_pooling_tasks:
302
303
304
305
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )
306

307
        if request.kv_transfer_params is not None and (
308
309
310
311
312
313
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )
Robert Shaw's avatar
Robert Shaw committed
314

315
        self.scheduler.add_request(request)
316

317
    def abort_requests(self, request_ids: list[str]):
        """Finish the given requests in the scheduler with the aborted status."""
        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        scheduler = self.scheduler
        scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)
324

325
326
    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Dump engine state when the wrapped code raises, then re-raise.

        Args:
            scheduler_output: Scheduler output of the step being executed;
                passed to ``dump_engine_exception`` together with the config
                and the scheduler's current stats.

        Raises:
            Exception: whatever the wrapped block raised, unchanged.
        """
        try:
            yield
        except Exception:
            # BaseException is deliberately not caught: the dump is only of
            # interest when the failure comes from model execution itself.

            # NOTE: dump_engine_exception is exception-free.
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare `raise` re-raises the active exception as-is, without
            # recording this handler's line in the traceback a second time.
            raise

341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
    @contextmanager
    def log_iteration_details(self, scheduler_output: SchedulerOutput):
        if not self.vllm_config.observability_config.enable_logging_iteration_details:
            yield
            return
        self._iteration_index = getattr(self, "_iteration_index", 0)
        iteration_details = compute_iteration_details(scheduler_output)
        before = time.monotonic()
        yield
        logger.info(
            "".join(
                [
                    "Iteration(",
                    str(self._iteration_index),
                    "): ",
                    str(iteration_details.num_ctx_requests),
                    " context requests, ",
                    str(iteration_details.num_ctx_tokens),
                    " context tokens, ",
                    str(iteration_details.num_generation_requests),
                    " generation requests, ",
                    str(iteration_details.num_generation_tokens),
                    " generation tokens, iteration elapsed time: ",
                    format((time.monotonic() - before) * 1000, ".2f"),
                    " ms",
                ]
            )
        )
        self._iteration_index += 1

371
    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
372
373
374
375
376
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """
377

378
379
380
        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
381
            return {}, False
382
383
384
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
385
386
387
388
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
389
390
391
392
            model_output = future.result()
            if model_output is None:
                model_output = self.model_executor.sample_tokens(grammar_output)

393
394
395
        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
396
397
398
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )
399

400
        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
401

402
    def post_step(self, model_executed: bool) -> None:
403
404
405
406
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
407
408
409
410
411
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

412
    def step_with_batch_queue(
413
        self,
414
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
415
416
417
418
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
419
420
421
422
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
423
424
425
426
427
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
428
429
        batch_queue = self.batch_queue
        assert batch_queue is not None
430

431
432
433
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
434
        assert len(batch_queue) < self.batch_queue_size
435

436
        model_executed = False
437
        deferred_scheduler_output = None
438
        if self.scheduler.has_requests():
439
440
441
442
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
443
            if not self.is_ec_producer:
444
                model_executed = scheduler_output.total_num_scheduled_tokens > 0
445

446
            if self.is_pooling_model or not model_executed:
447
448
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
449
            else:
450
451
452
                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
453
454
455
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
456
457
458
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
459
                else:
460
461
462
463
464
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            if not deferred_scheduler_output:
465
                # Add this step's future to the queue.
466
                batch_queue.appendleft((future, scheduler_output, exec_future))
467
468
469
470
471
472
473
474
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True
475
476
477
478
479
480

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False
481
482

        # Block until the next result is available.
483
        future, scheduler_output, exec_model_fut = batch_queue.pop()
484
485
486
487
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
488
            model_output = future.result()
489
490
491
492
493
            if model_output is None:
                # None from sample_tokens() implies that the original execute_model()
                # call failed - raise that exception.
                exec_model_fut.result()
                raise RuntimeError("unexpected error")
494

495
496
497
        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
498
499
500
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )
501
502
503
504
505

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output:
506
507
508
509
510
511
512
513
514
515
516
517
            # If we are doing speculative decoding with structured output,
            # we need to get the draft token ids from the prior step before
            # we can compute the grammar bitmask for the deferred request.
            if self.use_spec_decode:
                draft_token_ids = self.model_executor.take_draft_token_ids()
                assert draft_token_ids is not None
                # Update the draft token ids in the scheduler output to
                # filter out the invalid spec tokens, which will be padded
                # with -1 and skipped by the grammar bitmask computation.
                self.scheduler.update_draft_token_ids_in_output(
                    draft_token_ids, deferred_scheduler_output
                )
518
519
520
521
522
523
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
524
            batch_queue.appendleft((future, deferred_scheduler_output, exec_future))
525

526
        return engine_core_outputs, model_executed
527

528
529
530
531
532
    def _process_aborts_queue(self):
        if not self.aborts_queue.empty():
            request_ids = []
            while not self.aborts_queue.empty():
                ids = self.aborts_queue.get_nowait()
533
534
                # Should be a list here, but also handle string just in case.
                request_ids.extend((ids,) if isinstance(ids, str) else ids)
535
536
537
            # More efficient to abort all as a single batch.
            self.abort_requests(request_ids)

538
    def shutdown(self):
539
        self.structured_output_manager.clear_backend()
540
541
        if self.model_executor:
            self.model_executor.shutdown()
542
543
        if self.scheduler:
            self.scheduler.shutdown()
544

545
    def profile(self, is_start: bool = True):
546
        self.model_executor.profile(is_start)
547

548
549
    def reset_mm_cache(self):
        # NOTE: Since this is mainly for debugging, we don't attempt to
550
        # re-sync the internal caches (P0 sender, P1 receiver)
551
        if self.scheduler.has_unfinished_requests():
552
553
554
555
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )
556

557
        # The cache either exists in EngineCore or WorkerWrapperBase
558
559
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()
560

561
562
        self.model_executor.reset_mm_cache()

563
564
565
566
567
568
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        return self.scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )
569

570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        Clears both the scheduler's cache manager and the GPU model runner's cache.
        """
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the encoder cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # Reset the scheduler's encoder cache manager (logical state)
        self.scheduler.reset_encoder_cache()
        # Reset the GPU model runner's encoder cache (physical storage)
        self.model_executor.reset_encoder_cache()

590
591
592
    def sleep(self, level: int = 1):
        self.model_executor.sleep(level)

593
    def wake_up(self, tags: list[str] | None = None):
594
        self.model_executor.wake_up(tags)
595

596
597
598
    def is_sleeping(self) -> bool:
        return self.model_executor.is_sleeping

599
    def execute_dummy_batch(self):
600
        self.model_executor.execute_dummy_batch()
601

602
603
604
605
606
607
    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_executor.remove_lora(lora_id)

608
    def list_loras(self) -> set[int]:
609
610
611
612
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)
613

614
615
616
    def save_sharded_state(
        self,
        path: str,
617
618
        pattern: str | None = None,
        max_size: int | None = None,
619
    ) -> None:
620
621
622
623
624
625
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
626
627
        method: str | Callable[..., _R],
        timeout: float | None = None,
628
        args: tuple = (),
629
        kwargs: dict[str, Any] | None = None,
630
631
    ) -> list[_R]:
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)
632

633
    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Turn an engine-core request into a ``Request``.

        This function can be used directly in the input processing thread so
        that request initialization runs in parallel with the model forward.

        Returns:
            The built request together with ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        receiver_cache = self.mm_receiver_cache
        if receiver_cache is not None and request.mm_features:
            updated_features = receiver_cache.get_and_update_features(
                request.mm_features
            )
            request.mm_features = updated_features

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)

        wave = request.current_wave
        return req, wave

657
658
659
660

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process.

    Client requests are read from ZMQ sockets by an input thread and handed to
    the busy loop through ``input_queue``; results are placed on
    ``output_queue`` and written to ZMQ sockets by an output thread.
    """

661
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
662

663
    @instrument(span_name="EngineCoreProc init")
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        *,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine and start IO threads.

        The queues and the engine identity are created first, then everything
        else happens inside the handshake context: the base ``EngineCore`` is
        initialized, the input/output socket threads are started, and the
        constructor blocks until the input thread signals that it is ready.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        def executor_fail_callback():
            # Report an executor failure to the busy loop via the input queue.
            return self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b"")
            )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        handshakes = self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        )
        with handshakes as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            stats_address = addresses.frontend_stats_publish_address
            self.frontend_stats_publish_address = stats_address
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            internal_dp_balancing = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = internal_dp_balancing

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config,
                executor_class,
                log_stats,
                executor_fail_callback,
                internal_dp_balancing,
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_args = (
                addresses.inputs,
                addresses.coordinator_input,
                identity,
                ready_event,
            )
            input_thread = threading.Thread(
                target=self.process_input_sockets, args=input_args, daemon=True
            )
            input_thread.start()

            output_args = (
                addresses.outputs,
                addresses.coordinator_output,
                self.engine_index,
            )
            self.output_thread = threading.Thread(
                target=self.process_output_sockets, args=output_args, daemon=True
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while True:
                if ready_event.wait(timeout=10):
                    break
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

Rui Qiao's avatar
Rui Qiao committed
759
    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.
        """
        input_ctx = zmq.Context()
        # The first handshake only counts as local when no separate client
        # handshake address was supplied.
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        primary = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )

        secondary = None
        if client_handshake_address is not None:
            assert local_client
            secondary = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )

        with ExitStack() as stack:
            addresses = stack.enter_context(primary)
            if secondary is not None:
                # The client input/output addresses come from the second
                # handshake and replace those of the first.
                client_addresses = stack.enter_context(secondary)
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
            yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one handshake over a DEALER socket connected to the front-end.

        The addresses returned by ``startup_handshake`` are yielded to the
        caller. Once the caller's ``with`` body completes normally, a READY
        message is sent on the same socket before it is closed.
        """
        dealer = make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        )
        with dealer as handshake_socket:
            # Register engine with front-end.
            yield self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )

            # Send ready message. These values are read only now, after the
            # caller's `with` body has run.
            parallel_config = vllm_config.parallel_config
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                # We pass back the coordinator stats update address here for
                # the external LB case for our colocated front-end to use
                # (coordinator only runs with rank 0).
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            # Include config hash for DP configuration validation
            if parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = parallel_config.compute_hash()

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))
Rui Qiao's avatar
Rui Qiao committed
864

865
    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Exchange HELLO / init messages with the front-end.

        Sends a HELLO registration message, waits up to
        ``HANDSHAKE_TIMEOUT_MINS`` minutes for the init message, and returns
        the addresses it carries. When ``parallel_config`` is given, every
        key/value pair in the message's ``parallel_config`` mapping is set on
        it as an attribute.

        Raises:
            RuntimeError: if no response arrives within the timeout.
        """
        # Send registration message.
        hello_msg = {
            "status": "HELLO",
            "local": local_client,
            "headless": headless,
        }
        handshake_socket.send(msgspec.msgpack.encode(hello_msg))

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is None:
            return init_message.addresses

        for key, value in init_message.parallel_config.items():
            setattr(parallel_config, key, value)
        return init_message.addresses
902
903

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, configures the process for its data
        parallel rank, constructs either a ``DPEngineCoreProc`` or an
        ``EngineCoreProc`` and runs its busy loop. ``kwargs`` must contain
        ``vllm_config``. The engine is always shut down on the way out.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if shutdown_requested:
                return
            shutdown_requested = True
            raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            vllm_config: VllmConfig = kwargs["vllm_config"]
            parallel_config: ParallelConfig = vllm_config.parallel_config
            data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0

            if data_parallel:
                parallel_config.data_parallel_rank_local = local_dp_rank
                tracer_process_name = f"EngineCore_DP{dp_rank}"
                title_parts = ("EngineCore", f"DP{dp_rank}")
            else:
                tracer_process_name = "EngineCore"
                title_parts = ("EngineCore",)
            maybe_init_worker_tracer(
                instrumenting_module_name="vllm.engine_core",
                process_kind="engine_core",
                process_name=tracer_process_name,
            )
            set_process_title(*title_parts)
            decorate_logs()

            kv_transfer_config = vllm_config.kv_transfer_config
            if data_parallel and kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to
                # ensure that the kv_transfer_config is unique for each DP rank.
                kv_transfer_config.engine_id = (
                    f"{kv_transfer_config.engine_id}_dp{local_dp_rank}"
                )
                logger.debug(
                    "Setting kv_transfer_config.engine_id to %s",
                    kv_transfer_config.engine_id,
                )

            parallel_config.data_parallel_index = dp_rank
            if data_parallel and vllm_config.model_config.is_moe:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                # Non-MoE DP ranks are completely independent, so treat like
                # DP=1. Note that parallel_config.data_parallel_index will
                # still reflect the original DP rank.
                parallel_config.data_parallel_size = 1
                parallel_config.data_parallel_size_local = 1
                parallel_config.data_parallel_rank = 0
                engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)

            assert engine_core is not None
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception as e:
            if engine_core is not None:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            else:
                logger.exception("EngineCore failed to start.")
            raise e
        finally:
            if engine_core is not None:
                engine_core.shutdown()

989
990
991
    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook invoked from ``__init__`` before the base engine is built.

        Deliberately does nothing here; ``DPEngineCoreProc`` overrides it.
        """
        return None

992
993
994
    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Alternates between draining client input and stepping the engine. The
        loop has no exit condition of its own: it runs until the process is
        sent a SIGINT or SIGTERM, or one of the two phases raises.
        """
        while True:
            # Phase 1: poll the input queue until there is work to do.
            self._process_input_queue()
            # Phase 2: step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
1006
1007
1008
1009
1010
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
        ):
1011
1012
1013
1014
1015
1016
1017
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    waited = True
1018
1019
1020
1021
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
1022
            logger.debug("EngineCore loop active.")
1023
1024
1025
1026
1027
1028

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

1029
    def _process_engine_step(self) -> bool:
1030
1031
1032
        """Called only when there are unfinished local requests."""

        # Step the engine core.
1033
        outputs, model_executed = self.step_fn()
1034
        # Put EngineCoreOutputs into the output queue.
1035
        for output in outputs.items() if outputs else ():
1036
            self.output_queue.put_nowait(output)
1037
1038
        # Post-step hook.
        self.post_step(model_executed)
1039

1040
1041
1042
1043
1044
1045
1046
        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

1047
1048
        return model_executed

1049
1050
1051
    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        - ``ADD``: ``request`` is a ``(req, request_wave)`` pair passed to
          ``add_request``.
        - ``ABORT``: ``request`` is passed to ``abort_requests``.
        - ``UTILITY``: ``request`` is ``(client_idx, call_id, method_name,
          args)``; the named method is invoked on ``self`` and its result, or
          a failure message, is queued back to the client.
        - ``EXECUTOR_FAILED``: raises ``RuntimeError``.
        - Anything else is logged as an error and ignored.

        Raises:
            RuntimeError: for ``EXECUTOR_FAILED`` requests.
            SystemExit, KeyboardInterrupt: propagated unchanged if raised
                while a utility method is running.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except (SystemExit, KeyboardInterrupt):
                # Termination requests must not be turned into a utility
                # failure: the signal handler in run_engine_core raises
                # SystemExit only once, so swallowing it here would leave the
                # process running with no way to shut it down gracefully.
                raise
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {e}"
                )
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )
1080
1081
1082
1083

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
1084
        arg type, try converting to msgspec object."""
1085
1086
1087
1088
1089
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
1090
1091
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
1092
            and issubclass(p.annotation, msgspec.Struct)
1093
1094
1095
1096
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )
1097

1098
1099
1100
1101
1102
1103
1104
1105
1106
    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
1107
1108
1109
1110
            logger.fatal(
                "vLLM shutdown signal from EngineCore failed "
                "to send. Please report this issue."
            )
1111

1112
1113
1114
    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects a DEALER socket to every address in ``input_addresses`` and,
        when ``coord_input_address`` is given, an XSUB socket to the
        coordinator. After setup ``ready_event`` is set and the thread loops
        forever: each received message is decoded and pushed to
        ``input_queue`` as a ``(request_type, request)`` tuple.

        Raises:
            RuntimeError: if the first coordinator message is not ``READY``.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() is done
                # outside of any `assert`: assert statements are removed under
                # `python -O`, which would skip the recv and leave the READY
                # message to be misread as a request by the loop below.
                ready_msg = coord_socket.recv()
                if ready_msg != b"READY":
                    raise RuntimeError(
                        f"Expected READY message from DP Coordinator, got {ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            # Report the failure for this request only and keep
                            # the IO thread alive for the others.
                            self._handle_request_preproc_error(req)
                            continue
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

1193
1194
1195
    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Takes items from ``output_queue`` and sends them on PUSH sockets: one
        per entry in ``output_paths`` (selected by the item's client index)
        plus an optional coordinator socket used for client index ``-1``. The
        thread ends after forwarding ``ENGINE_CORE_DEAD`` to every client
        socket.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        spare_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        in_flight = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = None
            if coord_output_path is not None:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while in_flight and in_flight[-1][0].done:
                    spare_buffers.append(in_flight.pop()[2])

                buffer = spare_buffers.pop() if spare_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if tracker.done:
                    # Limit the number of buffers to reuse.
                    if len(spare_buffers) < max_reuse_bufs:
                        spare_buffers.append(buffer)
                    continue

                # Send still in progress: hold on to the buffer, and to the
                # outputs too when more than one frame was produced.
                keepalive = outputs if len(buffers) > 1 else None
                in_flight.appendleft((tracker, keepalive, buffer))
1262

1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Log and return a request-scoped error response for exceptions raised
        from the add request preprocessing in the input socket processing thread.

        The response marks the request as finished with ``FinishReason.ERROR``
        and no new tokens, and is queued for the request's client index.
        """
        request_id = request.request_id
        logger.exception("Unexpected error pre-processing request %s", request_id)

        error_output = EngineCoreOutput(
            request_id=request_id,
            new_token_ids=[],
            finish_reason=FinishReason.ERROR,
        )
        response = EngineCoreOutputs(
            engine_index=self.engine_index,
            finished_requests={request_id},
            outputs=[error_output],
        )
        self.output_queue.put_nowait((request.client_index, response))

1287
1288
1289
1290
1291
1292
1293
1294

class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        """Set up wave/step bookkeeping, then initialize the base engine.

        The DP rank found in ``vllm_config.parallel_config`` is handed to the
        base class as ``engine_index``; every other argument is forwarded to
        ``EngineCoreProc.__init__`` unchanged.
        """
        assert vllm_config.model_config.is_moe, (
            "DPEngineCoreProc should only be used for MoE models"
        )

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last value returned by scheduler.get_request_counts() that was
        # published; see _maybe_publish_request_counts().
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            engine_index=dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate the DP rank settings and create the DP process group.

        Stores the global DP rank in ``self.dp_rank`` and the group returned
        by ``parallel_config.stateless_init_dp_group()`` in ``self.dp_group``.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the base engine, then destroy the DP process group."""
        super().shutdown()
        # getattr: dp_group may not have been assigned on this instance.
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, reconciling its wave number with the current wave.

        Only applies when a coordinator is in use and ``request_wave`` differs
        from ``self.current_wave``: a newer wave advances ``current_wave``,
        while an older wave received when the engines are not running queues
        a ``start_wave`` notification. The request is then passed to the base
        class in every case.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle ``START_DP_WAVE`` here; defer other types to the base class.

        For ``START_DP_WAVE`` the payload is unpacked as
        ``(new_wave, exclude_eng_index)``. It is ignored when
        ``exclude_eng_index`` equals this engine's index or when ``new_wave``
        is older than the current wave.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Queue scheduler request counts if publishing is on and they changed.

        The stats are queued with client index -1, together with the current
        step counter and wave number.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether the DP group should still be treated as running.

        Increments ``self.step_counter`` on every call. On calls where the
        counter is not a multiple of 32 this returns ``True`` without
        consulting the group; otherwise it returns the result of
        ``ParallelConfig.has_unfinished_dp`` for ``local_unfinished``.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down the DP group and rebuild it from ``reconfig_request``.

        Updates DP size, rank, master IP and master port on
        ``self.vllm_config.parallel_config``, re-creates ``self.dp_group``
        (unless the requested rank is -2), asks the model executor to
        reinitialize, and shares the KV-cache memory size and warms up the
        model when the DP size grew. Shuts this engine down when the requested
        rank is ``ReconfigureRankType.SHUTDOWN_CURRENT_RANK``.

        ``reconfig_request.new_data_parallel_master_port`` is overwritten with
        the port held by the parallel config after the group is re-created.
        """
        # NOTE(review): shutdown() as defined on this class also destroys
        # self.dp_group, so the same group reaches the destroy helper twice
        # here - confirm the helper tolerates that.
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        # NOTE(review): -1 and -2 below presumably correspond to
        # ReconfigureRankType.KEEP_CURRENT_RANK / SHUTDOWN_CURRENT_RANK -
        # confirm against the enum definition.
        if reconfig_request.new_data_parallel_rank != -1:
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if reconfig_request.new_data_parallel_rank != -2:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Write the config's port back into the request before it is handed
        # to the executor (presumably group init may have changed it - TODO
        # confirm).
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )
1488

Rui Qiao's avatar
Rui Qiao committed
1489

1490
class EngineCoreActorMixin:
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Prepare actor-level state before the engine core is constructed.

        Initializes the worker tracer, stores ``addresses`` for
        ``_perform_handshakes``, writes ``dp_rank`` / ``local_dp_rank`` into
        ``vllm_config.parallel_config`` (as ``data_parallel_index`` and
        ``data_parallel_rank_local``), and sets the device-visibility
        environment variable for this actor.
        """
        # Initialize tracer for distributed tracing if configured.
        maybe_init_worker_tracer(
            instrumenting_module_name="vllm.engine_core",
            process_kind="engine_core",
            process_name=f"DPEngineCoreActor_DP{dp_rank}",
        )

        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_index = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Set the platform's device-control env var for this local DP rank.

        Does nothing when the current platform reports itself as XPU.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            # Device visibility is left untouched on XPU.
            return

        device_control_env_var = current_platform.device_control_env_var
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Write this rank's device indices into ``device_control_env_var``.

        The value comes from ``get_device_indices`` for ``local_dp_rank`` and
        the configured world size.

        Raises:
            Exception: if ``get_device_indices`` raises ``IndexError``; the
                message reports the local index range and the variable's
                current value, and the ``IndexError`` is chained as the cause.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
            os.environ[device_control_env_var] = value
        except IndexError as e:
            raise Exception(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        ``shutdown()`` is always called when the loop exits, whether it
        returns or raises; ``SystemExit`` and other exceptions are logged
        and re-raised.
        """
        try:
            self.run_busy_loop()  # type: ignore[attr-defined]
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()  # type: ignore[attr-defined]


class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
    """Engine-core actor for data-parallel MoE models."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Record this actor's DP rank in the config before either base
        # initializer runs.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank

        # The two bases take different constructor arguments, so each one is
        # initialized explicitly instead of going through super().
        EngineCoreActorMixin.__init__(
            self,
            vllm_config=vllm_config,
            addresses=addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # An empty handshake address is passed to the engine-core base.
        DPEngineCoreProc.__init__(
            self,
            vllm_config=vllm_config,
            local_client=local_client,
            handshake_address="",
            executor_class=executor_class,
            log_stats=log_stats,
        )


class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
    """Engine-core actor for the non-MoE and/or non-DP cases."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Force a single-rank data-parallel layout in the config; dp_rank is
        # still used below as the engine index and by the mixin.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_size = 1
        parallel_config.data_parallel_size_local = 1
        parallel_config.data_parallel_rank = 0

        # The two bases take different constructor arguments, so each one is
        # initialized explicitly instead of going through super().
        EngineCoreActorMixin.__init__(
            self,
            vllm_config=vllm_config,
            addresses=addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # An empty handshake address is passed to the engine-core base.
        EngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
            engine_index=dp_rank,
        )