core.py 63.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from inspect import isclass, signature
13
from logging import DEBUG
14
from typing import Any, TypeVar, cast
15

16
import msgspec
17
18
import zmq

19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
21
from vllm.envs import enable_envs_cache
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.tasks import POOLING_TASKS, SupportedTask
27
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
28
29
30
31
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
32
from vllm.utils.hashing import get_hash_fn_by_name
33
from vllm.utils.network_utils import make_zmq_socket
34
from vllm.utils.system_utils import decorate_logs, set_process_title
35
36
37
38
39
40
41
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
42
from vllm.v1.core.sched.interface import SchedulerInterface
43
from vllm.v1.core.sched.output import SchedulerOutput
44
from vllm.v1.engine import (
45
    EngineCoreOutput,
46
47
48
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
49
    FinishReason,
50
51
52
53
54
55
56
57
58
59
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
60
from vllm.v1.executor import Executor
61
from vllm.v1.kv_cache_interface import KVCacheConfig
62
from vllm.v1.metrics.stats import SchedulerStats
63
from vllm.v1.outputs import ModelRunnerOutput
64
from vllm.v1.request import Request, RequestStatus
65
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
66
from vllm.v1.structured_output import StructuredOutputManager
67
from vllm.v1.utils import compute_iteration_details
68
from vllm.version import __version__ as VLLM_VERSION
xuxz's avatar
xuxz committed
69
from vllm import envs
70
71
72

logger = init_logger(__name__)

# NOTE(review): units presumably follow the name suffixes (_S = seconds,
# _MINS = minutes); both constants are consumed outside this view.
POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar("_R")  # Return type for collective_rpc
77

78
79
80
81

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the structured-output manager.
    One scheduling iteration is driven through ``self.step_fn``, which is
    ``step`` when no batch queue is configured and ``step_with_batch_queue``
    when the executor reports more than one concurrent batch.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
        include_finished_set: bool = False,
    ):
        """Create the executor, size/allocate the KV cache and build the scheduler.

        Args:
            vllm_config: Engine configuration. Mutated in place:
                ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks`` are set
                from profiling, and ``scheduler_config.enable_chunked_prefill``
                is switched off for models without KV cache groups.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
            include_finished_set: Forwarded to the scheduler.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if not vllm_config.parallel_config.data_parallel_rank_local:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler works on blocks of cache_config.block_size multiplied
        # by both the decode and prefill context-parallel sizes.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=include_finished_set,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
            vllm_config
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.is_ec_producer = (
            vllm_config.ec_transfer_config is not None
            and vllm_config.ec_transfer_config.is_ec_producer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Block hashing is needed for prefix caching and for KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Aborts that arrive while a step is running; drained by
        # _process_aborts_queue() before each scheduler update.
        self.aborts_queue = queue.Queue[list[str]]()

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()
        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV cache from available memory and initialize it on workers.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 here.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments (unlike time.time()).
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic-EP scale-up: take the KV memory size from the DP
                # group instead of profiling locally.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model executor reports as supported."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case

        Raises:
            TypeError: If ``request.request_id`` is not a ``str``.
            ValueError: If the request carries pooling params whose task is
                not among the supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Execute the model and log detailed info on failure."""
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise re-raises the active exception with its original
            # traceback.
            raise

    @contextmanager
    def log_iteration_details(self, scheduler_output: SchedulerOutput):
        """Log request/token counts and elapsed time of the wrapped iteration.

        Does nothing unless
        ``observability_config.enable_logging_iteration_details`` is set.
        If the wrapped body raises, nothing is logged and the iteration
        index is not advanced.
        """
        if not self.vllm_config.observability_config.enable_logging_iteration_details:
            yield
            return
        self._iteration_index = getattr(self, "_iteration_index", 0)
        iteration_details = compute_iteration_details(scheduler_output)
        before = time.monotonic()
        yield
        elapsed_ms = (time.monotonic() - before) * 1000
        # Lazy %-style arguments: the message is only formatted if INFO
        # logging is actually enabled.
        logger.info(
            "Iteration(%s): %s context requests, %s context tokens, "
            "%s generation requests, %s generation tokens, "
            "iteration elapsed time: %.2f ms",
            self._iteration_index,
            iteration_details.num_ctx_requests,
            iteration_details.num_ctx_tokens,
            iteration_details.num_generation_requests,
            iteration_details.num_generation_tokens,
            elapsed_ms,
        )
        self._iteration_index += 1

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        # Computed while the model runs, since execute_model is non-blocking.
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                model_output = self.model_executor.sample_tokens(grammar_output)

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Pull draft token ids into the scheduler after a synchronous step.

        Only acts when async scheduling is off, speculative decoding is on
        and the model actually executed in the step.
        """
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the batch queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        deferred_scheduler_output: SchedulerOutput | None = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            if not self.is_ec_producer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            if deferred_scheduler_output is None:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output, exec_future))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output, exec_model_fut = batch_queue.pop()
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # None from sample_tokens() implies that the original execute_model()
                # call failed - raise that exception.
                exec_model_fut.result()
                raise RuntimeError("unexpected error")

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # If we are doing speculative decoding with structured output,
            # we need to get the draft token ids from the prior step before
            # we can compute the grammar bitmask for the deferred request.
            if self.use_spec_decode:
                draft_token_ids = self.model_executor.take_draft_token_ids()
                assert draft_token_ids is not None
                # Update the draft token ids in the scheduler output to
                # filter out the invalid spec tokens, which will be padded
                # with -1 and skipped by the grammar bitmask computation.
                self.scheduler.update_draft_token_ids_in_output(
                    draft_token_ids, deferred_scheduler_output
                )
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            batch_queue.appendleft((future, deferred_scheduler_output, exec_future))

        return engine_core_outputs, model_executed

    def _process_aborts_queue(self):
        """Drain ``aborts_queue`` and abort every collected id in one call."""
        if not self.aborts_queue.empty():
            request_ids = []
            while not self.aborts_queue.empty():
                ids = self.aborts_queue.get_nowait()
                # Should be a list here, but also handle string just in case.
                request_ids.extend((ids,) if isinstance(ids, str) else ids)
            # More efficient to abort all as a single batch.
            self.abort_requests(request_ids)

    def shutdown(self):
        """Clear the structured-output backend, then shut down executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiler start (``is_start=True``) or stop to the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the engine-side and executor-side multi-modal caches."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Delegate to the scheduler's ``reset_prefix_cache`` and return its result."""
        return self.scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def sleep(self, level: int = 1):
        """Delegate to ``model_executor.sleep(level)``."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Delegate to ``model_executor.wake_up(tags)``."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Delegate to ``model_executor.execute_dummy_batch()``."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate LoRA loading to the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate LoRA removal to the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate LoRA pinning to the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Delegate to ``model_executor.save_sharded_state`` with the same arguments."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke ``method`` via the executor's ``collective_rpc`` and return its results."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

636
637
638
639

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process.

    A background input thread reads client requests from ZMQ sockets and
    pushes them onto ``input_queue``. The busy loop drains that queue and
    steps the engine, putting results onto ``output_queue``. A background
    output thread serializes those results and sends them back over ZMQ.
    """

    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        *,
        engine_index: int = 0,
    ):
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()
        # Executor failures are funneled into the busy loop as a special
        # request type so they are raised on the engine core thread.
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            internal_dp_balancing = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = internal_dp_balancing

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config,
                executor_class,
                log_stats,
                executor_fail_callback,
                internal_dp_balancing,
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.
        """
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        handshake = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
            assert local_client
            local_handshake = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )
            # Socket addresses for client IO come from the colocated front-end;
            # everything else comes from the primary handshake.
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Register with one front-end, yield its addresses, then send READY.

        The READY message is only sent after the ``with`` body of the caller
        completes, i.e. once engine initialization has finished.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, wait for the front-end's init message, return addresses.

        If ``parallel_config`` is given, every key/value in the init message's
        ``parallel_config`` mapping is applied to it via ``setattr``.

        Raises:
            RuntimeError: if no init message arrives within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process."""

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            vllm_config: VllmConfig = kwargs["vllm_config"]
            parallel_config: ParallelConfig = vllm_config.parallel_config
            data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
            if data_parallel:
                parallel_config.data_parallel_rank_local = local_dp_rank
                set_process_title("EngineCore", f"DP{dp_rank}")
            else:
                set_process_title("EngineCore")
            decorate_logs()

            if data_parallel and vllm_config.kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to ensure
                # that the kv_transfer_config is unique for each DP rank.
                vllm_config.kv_transfer_config.engine_id = (
                    f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
                )
                logger.debug(
                    "Setting kv_transfer_config.engine_id to %s",
                    vllm_config.kv_transfer_config.engine_id,
                )

            parallel_config.data_parallel_index = dp_rank
            if data_parallel and vllm_config.model_config.is_moe:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                # Non-MoE DP ranks are completely independent, so treat like DP=1.
                # Note that parallel_config.data_parallel_index will still reflect
                # the original DP rank.
                parallel_config.data_parallel_size = 1
                parallel_config.data_parallel_size_local = 1
                parallel_config.data_parallel_rank = 0
                engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare `raise` re-raises the active exception with its original
            # traceback untouched.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook for data-parallel setup; no-op for the non-DP engine."""
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
        ):
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    waited = True
            # Blocks until the input thread enqueues something.
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in outputs.items() if outputs else ():
            self.output_queue.put_nowait(output)
        # Post-step hook.
        # NOTE(review): the post-step hook call `self.post_step(model_executed)`
        # is disabled in this tree; confirm this is intentional, since any
        # per-step work it performs is currently skipped.

        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

        return model_executed

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Raises:
            RuntimeError: when the executor-failure sentinel is received.
        """

        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except BaseException as e:
                # Utility failures are reported back to the caller rather
                # than taking down the busy loop.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal(
                "vLLM shutdown signal from EngineCore failed "
                "to send. Please report this issue."
            )

    def _maybe_register_dp_connector_request(self, req: Request) -> None:
        """Register a newly added request with the scheduler's connector.

        This only happens when the scheduler has a connector and
        ``envs.VLLM_USE_DP_CONNECTOR`` is enabled. It is called from the
        input socket thread before the request is handed to the busy loop.
        """
        connector = self.scheduler.connector
        if connector is not None and envs.VLLM_USE_DP_CONNECTOR:
            connector.register_req(req.request_id)

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread."""

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator.
                assert coord_socket.recv() == b"READY"
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            self._handle_request_preproc_error(req)
                            continue
                        # Register with the DP connector *before* enqueueing
                        # so the busy loop can never add (or finish) a request
                        # that the connector has not seen yet.
                        self._maybe_register_dp_connector_request(request[0])
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread."""

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)

    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Log and return a request-scoped error response for exceptions raised
        from the add request preprocessing in the input socket processing thread.
        """
        logger.exception(
            "Unexpected error pre-processing request %s", request.request_id
        )
        self.output_queue.put_nowait(
            (
                request.client_index,
                EngineCoreOutputs(
                    engine_index=self.engine_index,
                    finished_requests={request.request_id},
                    outputs=[
                        EngineCoreOutput(
                            request_id=request.request_id,
                            new_token_ids=[],
                            finish_reason=FinishReason.ERROR,
                        )
                    ],
                ),
            )
        )

1259
1260
1261
1262
1263
1264
1265
1266

class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context.

    Only valid for MoE models (asserted in ``__init__``). On top of
    ``EngineCoreProc`` it tracks DP "waves": an engine keeps stepping (with
    dummy batches when it has no local work) while any DP peer still has
    unfinished requests. It periodically all-reduces the "unfinished" flag
    across the DP group to decide when a wave is complete.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        assert vllm_config.model_config.is_moe, (
            "DPEngineCoreProc should only be used for MoE models"
        )

        # These are set before super().__init__(), presumably because
        # base-class initialization may call overridden methods that read
        # them. NOTE(review): confirm against EngineCoreProc.__init__.
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last request counts published by _maybe_publish_request_counts().
        self.last_counts = (0, 0)

        # Initialize the engine; its engine index is the global DP rank.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            engine_index=dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate DP ranks and create the stateless DP process group."""
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def _destroy_dp_group(self) -> None:
        """Destroy the current DP process group, if any, at most once.

        The reference is cleared before destroying, so a later call (for
        example from ``shutdown()``) does not destroy the same group again.
        Safe to call when ``_init_data_parallel`` never ran.
        """
        if dp_group := getattr(self, "dp_group", None):
            self.dp_group = None
            stateless_destroy_torch_distributed_process_group(dp_group)

    def shutdown(self):
        """Shut down the base engine, then release the DP process group."""
        super().shutdown()
        self._destroy_dp_group()

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, reconciling its wave with this engine's wave.

        Only when a coordinator is present and ``request_wave`` differs from
        ``current_wave``:
        - A newer wave is adopted as the current wave.
        - An older wave while engines are idle triggers a ``start_wave``
          notification (client index -1) so the next wave gets started.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle ``START_DP_WAVE``; delegate other request types to the base.

        A ``START_DP_WAVE`` payload is ``(new_wave, exclude_eng_index)``. It
        is ignored by the excluded engine and when it refers to an older
        wave. Otherwise the wave is adopted and an idle engine starts
        running.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Queue ``SchedulerStats`` when LB stats are enabled and counts changed.

        Stats go to client index -1, the same index run_busy_loop uses for
        the coordinator.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        Between syncs this optimistically returns True. Only every 32nd call
        performs the collective via ``ParallelConfig.has_unfinished_dp``.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild DP state for a DP reconfiguration.

        Steps:
        1. Destroy the current DP group and shut the engine down.
        2. Apply the new DP size, rank, master IP, and master port to the
           parallel config.
        3. Recreate the DP group, unless this rank is being shut down.
        4. Forward the request to the model executor.
        5. When scaling up, share the KV-cache memory budget and warm up.
        6. Finally, shut down if this rank is the one being removed.
        """
        # Destroy the DP group exactly once. The helper clears the reference,
        # so the shutdown() below (and the final shutdown() on the
        # SHUTDOWN_CURRENT_RANK path) does not destroy it again.
        self._destroy_dp_group()
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Copy the (possibly updated) master port back onto the request
        # before it is forwarded to the executor. NOTE(review): presumably
        # stateless_init_dp_group() may change the port — confirm.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )
1460

Rui Qiao's avatar
Rui Qiao committed
1461

1462
class EngineCoreActorMixin:
    """
    Ray-actor behavior shared by the EngineCore actor classes.

    Mixed into both ``DPMoEEngineCoreActor`` (data-parallel MoE models) and
    ``EngineCoreActor`` (non-MoE and/or non-DP cases), so it is not specific
    to the data parallel context. Its responsibilities:
    - Store the pre-computed ZMQ addresses.
    - Record the DP index and local DP rank on the parallel config.
    - Restrict device visibility for the local DP rank.
    - Replace the ZMQ handshake with the already-known addresses.
    - Provide the ``wait_for_init`` and ``run`` actor entry points.

    ``__init__`` here does not call ``super().__init__()``; the concrete
    actor classes call each base initializer explicitly.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_index = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict the devices visible to this actor's local DP rank.

        This is a no-op on XPU. On other platforms it sets the platform's
        device-control environment variable via ``_set_cuda_visible_devices``.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            # NOTE(review): XPU is deliberately skipped; the reason is not
            # visible in this file.
            return

        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this local DP rank's device slice.

        The slice is chosen by ``get_device_indices`` from the rank and the
        parallel ``world_size``. The error message below describes it as the
        half-open range ``[local_dp_rank * world_size,
        (local_dp_rank + 1) * world_size)``.

        Raises:
            RuntimeError: if ``get_device_indices`` raises ``IndexError``
                (presumably because the slice exceeds the env var's current
                value); the original error is chained as the cause.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.

        The arguments are unused here.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        ``SystemExit`` is logged at debug level and re-raised. Any other
        exception is logged with its traceback and re-raised. ``shutdown()``
        always runs on exit.
        """
        try:
            self.run_busy_loop()  # type: ignore[attr-defined]
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()  # type: ignore[attr-defined]


class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
    """Ray actor for data-parallel deployments of MoE models."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Record the global DP rank on the config before either initializer
        # runs; DPEngineCoreProc.__init__ reads it as its engine index.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank,
            local_dp_rank,
        )
        # Ray actors get their addresses up front (see the mixin's
        # _perform_handshakes), so the handshake address is left empty.
        DPEngineCoreProc.__init__(
            self,
            vllm_config=vllm_config,
            local_client=local_client,
            handshake_address="",
            executor_class=executor_class,
            log_stats=log_stats,
        )


class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
    """Ray actor used when the model is non-MoE and/or DP is not in use."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # From the engine's point of view this actor is a single-rank DP
        # world, whatever dp_rank it was launched with.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_size = 1
        parallel_config.data_parallel_size_local = 1
        parallel_config.data_parallel_rank = 0

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank,
            local_dp_rank,
        )
        # dp_rank is still used as the engine index; the handshake address
        # is empty because the mixin yields pre-known addresses instead.
        EngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
            engine_index=dp_rank,
        )